From ed0eee689800fffcf2a324d99ac4ef0c84a6f843 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 9 Aug 2019 04:48:29 -0700 Subject: [PATCH] BlockLSTM and LSTMBlockCell are now generic wrt to the gate layout The default layout is still ICFO, but V2 version of the op will use IFCO to match CuDNN-RNN. PiperOrigin-RevId: 262538325 --- tensorflow/core/kernels/rnn/lstm_ops.cc | 192 +++++++++--------- tensorflow/core/kernels/rnn/lstm_ops.h | 30 ++- .../core/kernels/rnn/lstm_ops_gpu.cu.cc | 152 +++++++------- 3 files changed, 201 insertions(+), 173 deletions(-) diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc index e9f15d278a9..aaaa168c58e 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops.cc @@ -41,7 +41,7 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { -template +template void LSTMBlockCellFpropWithEigen( const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, const float cell_clip, bool use_peephole, @@ -83,18 +83,21 @@ void LSTMBlockCellFpropWithEigen( // Cell input. ci.device(d) = - gates.slice(cell.gates_c_offsets(), cell.cell_extents()).tanh(); + gates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents()) + .tanh(); // Forget gate (w/ bias). if (use_peephole) { auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) + - f.constant(T(forget_bias)) + f_peep) - .sigmoid(); + f.device(d) = + (gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) + + f.constant(T(forget_bias)) + f_peep) + .sigmoid(); } else { - f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) + - f.constant(T(forget_bias))) - .sigmoid(); + f.device(d) = + (gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) + + f.constant(T(forget_bias))) + .sigmoid(); } // cs = ci .* i + f .* cs_prev @@ -123,7 +126,7 @@ void LSTMBlockCellFpropWithEigen( h.device(d) = o * co; } -template +template void LSTMBlockCellBpropWithEigen( const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, bool use_peephole, typename TTypes::ConstMatrix x, @@ -164,8 +167,10 @@ void LSTMBlockCellBpropWithEigen( di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; dgates.slice(cell.gates_i_offsets(), cell.cell_extents()).device(d) = di; - dgates.slice(cell.gates_c_offsets(), cell.cell_extents()).device(d) = dci; - dgates.slice(cell.gates_f_offsets(), cell.cell_extents()).device(d) = df; + dgates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents()) + .device(d) = dci; + dgates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) + .device(d) = df; dgates.slice(cell.gates_o_offsets(), cell.cell_extents()).device(d) = do_; cs_prev_grad.device(d) = dcs * f; @@ -179,54 +184,58 @@ void LSTMBlockCellBpropWithEigen( } } -#define DECLARE_CPU_FBPROP(T) \ - template <> \ - void LSTMBlockCellFprop::operator()( \ - OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \ - const float cell_clip, bool use_peephole, \ - typename TTypes::ConstMatrix x, \ - typename TTypes::ConstMatrix cs_prev, \ - typename TTypes::ConstMatrix h_prev, \ - typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ - typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ - typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ - typename TTypes::Matrix i, typename TTypes::Matrix cs, \ - typename TTypes::Matrix f, typename TTypes::Matrix o, \ - typename TTypes::Matrix ci, typename TTypes::Matrix co, \ - typename TTypes::Matrix gates, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithEigen( \ - *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ - h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \ - } \ - template <> \ - void LSTMBlockCellBprop::operator()( \ - OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \ - typename TTypes::ConstMatrix x, \ - typename TTypes::ConstMatrix cs_prev, \ - typename TTypes::ConstMatrix h_prev, \ - typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ - typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ - typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ - typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ - typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ - typename TTypes::ConstMatrix co, \ - typename TTypes::ConstMatrix cs_grad, \ - typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ - typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ - typename TTypes::Matrix df, typename TTypes::Matrix di, \ - typename TTypes::Matrix dgates, \ - typename TTypes::Matrix cs_prev_grad, \ - typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ - typename TTypes::Vec wco_grad) { \ - LSTMBlockCellBpropWithEigen( \ - *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \ - i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \ - cs_prev_grad, wci_grad, wcf_grad, wco_grad); \ - } \ - template struct LSTMBlockCellFprop; \ - template struct LSTMBlockCellBprop; +#define DECLARE_CPU_FBPROP(T, GATE_LAYOUT) \ + template <> \ + void LSTMBlockCellFprop:: \ + operator()( \ + OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \ + const float cell_clip, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ + typename TTypes::Matrix i, typename TTypes::Matrix cs, \ + typename TTypes::Matrix f, typename TTypes::Matrix o, \ + typename TTypes::Matrix ci, typename TTypes::Matrix co, \ + typename TTypes::Matrix gates, typename TTypes::Matrix h) { \ + LSTMBlockCellFpropWithEigen( \ + *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ + h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \ + } \ + template <> \ + void LSTMBlockCellBprop:: \ + operator()( \ + OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ + typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ + typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ + typename TTypes::ConstMatrix co, \ + typename TTypes::ConstMatrix cs_grad, \ + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ + typename TTypes::Matrix df, typename TTypes::Matrix di, \ + typename TTypes::Matrix dgates, \ + typename TTypes::Matrix cs_prev_grad, \ + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ + typename TTypes::Vec wco_grad) { \ + LSTMBlockCellBpropWithEigen( \ + *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \ + i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \ + cs_prev_grad, wci_grad, wcf_grad, wco_grad); \ + } \ + template struct LSTMBlockCellFprop; \ + template struct LSTMBlockCellBprop; -#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T); +#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO); DECLARE_CPU_SPECS(Eigen::half); DECLARE_CPU_SPECS(float); @@ -234,9 +243,9 @@ DECLARE_CPU_SPECS(float); #undef DECLARE_CPU_FBPROP #if GOOGLE_CUDA -#define DECLARE_GPU_FBPROP(T) \ +#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \ template <> \ - void LSTMBlockCellFprop::operator()( \ + void LSTMBlockCellFprop::operator()( \ OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \ const float cell_clip, bool use_peephole, \ typename TTypes::ConstMatrix x, \ @@ -250,7 +259,7 @@ DECLARE_CPU_SPECS(float); typename TTypes::Matrix ci, typename TTypes::Matrix co, \ typename TTypes::Matrix gates, typename TTypes::Matrix h); \ template <> \ - void LSTMBlockCellBprop::operator()( \ + void LSTMBlockCellBprop::operator()( \ OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \ typename TTypes::ConstMatrix x, \ typename TTypes::ConstMatrix cs_prev, \ @@ -270,21 +279,20 @@ DECLARE_CPU_SPECS(float); typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ typename TTypes::Vec wco_grad); \ \ - extern template struct LSTMBlockCellBprop; \ - extern template struct LSTMBlockCellFprop; + extern template struct LSTMBlockCellBprop< \ + GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>; \ + extern template struct LSTMBlockCellFprop; -#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T); +#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T, ICFO); -DECLARE_GPU_SPECS(Eigen::half); DECLARE_GPU_SPECS(float); +DECLARE_GPU_SPECS(Eigen::half); #undef DECLARE_GPU_SPECS -#undef DECLARE_GPU_FBPROP +#undef DECLARE_GPU_FBROP #endif // GOOGLE_CUDA - } // namespace functor -template +template class LSTMBlockCellOp : public OpKernel { public: explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -406,8 +414,8 @@ class LSTMBlockCellOp : public OpKernel { const Device& device = ctx->eigen_device(); - functor::LSTMBlockCellFprop(batch_size, input_size, - cell_size)( + functor::LSTMBlockCellFprop( + batch_size, input_size, cell_size)( ctx, device, forget_bias_, cell_clip_, use_peephole_, x_tensor->matrix(), cs_prev_tensor->matrix(), h_prev_tensor->matrix(), w_tensor->matrix(), wci_tensor->vec(), @@ -427,7 +435,7 @@ class LSTMBlockCellOp : public OpKernel { #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint("T"), \ - LSTMBlockCellOp); + LSTMBlockCellOp); REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(float); @@ -437,14 +445,14 @@ REGISTER_KERNEL(float); #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint("T"), \ - LSTMBlockCellOp); + LSTMBlockCellOp); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA -template +template class LSTMBlockCellGradOp : public OpKernel { public: explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -668,8 +676,8 @@ class LSTMBlockCellGradOp : public OpKernel { functor::TensorZero()(device, wcf_grad_tensor->flat()); functor::TensorZero()(device, wco_grad_tensor->flat()); - functor::LSTMBlockCellBprop(batch_size, input_size, - cell_size)( + functor::LSTMBlockCellBprop( + batch_size, input_size, cell_size)( ctx, device, use_peephole_, x_tensor->matrix(), cs_prev_tensor->matrix(), h_prev_tensor->matrix(), w_tensor->matrix(), wci_tensor->vec(), wcf_tensor->vec(), @@ -691,7 +699,7 @@ class LSTMBlockCellGradOp : public OpKernel { #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - LSTMBlockCellGradOp); + LSTMBlockCellGradOp); REGISTER_KERNEL(float); REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL @@ -700,7 +708,7 @@ REGISTER_KERNEL(Eigen::half); #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - LSTMBlockCellGradOp); + LSTMBlockCellGradOp); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); @@ -815,7 +823,7 @@ class SliceHelper { } // namespace -template +template class BlockLSTMOp : public OpKernel { public: explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -972,8 +980,8 @@ class BlockLSTMOp : public OpKernel { Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out"); Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out"); - functor::LSTMBlockCellFprop(batch_size, input_size, - cell_size)( + functor::LSTMBlockCellFprop( + batch_size, input_size, cell_size)( ctx, device, forget_bias_, cell_clip_, use_peephole_, x_tensor.matrix(), cs_prev_tensor2.matrix(), h_prev_tensor2.matrix(), w_tensor->matrix(), @@ -1005,7 +1013,7 @@ class BlockLSTMOp : public OpKernel { #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint("T"), \ - BlockLSTMOp); + BlockLSTMOp); REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(float); @@ -1036,14 +1044,14 @@ DECLARE_GPU_SPECS(float); .Device(DEVICE_GPU) \ .HostMemory("seq_len_max") \ .TypeConstraint("T"), \ - BlockLSTMOp); + BlockLSTMOp); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA -template +template class BlockLSTMGradOp : public OpKernel { public: explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -1246,8 +1254,8 @@ class BlockLSTMGradOp : public OpKernel { const Tensor& const_h_grad_tensor = h_grad_tensor; Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad"); - functor::BlockLSTMBprop(batch_size, input_size, - cell_size)( + functor::BlockLSTMBprop( + batch_size, input_size, cell_size)( ctx, device, use_peephole_, x_tensor.matrix(), cs_prev_tensor2.matrix(), h_prev_tensor2.matrix(), w_tensor->matrix(), wci_tensor->vec(), wcf_tensor->vec(), @@ -1279,7 +1287,7 @@ class BlockLSTMGradOp : public OpKernel { #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - BlockLSTMGradOp); + BlockLSTMGradOp); REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(float); @@ -1287,9 +1295,9 @@ REGISTER_KERNEL(float); #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_BPROP(T) \ +#define DECLARE_GPU_BPROP(T, GATE_LAYOUT) \ template <> \ - void BlockLSTMBprop::operator()( \ + void BlockLSTMBprop::operator()( \ OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \ typename TTypes::ConstMatrix x, \ typename TTypes::ConstMatrix cs_prev, \ @@ -1311,7 +1319,7 @@ namespace functor { typename TTypes::Matrix w_grad, typename TTypes::Vec wci_grad, \ typename TTypes::Vec wcf_grad, typename TTypes::Vec wco_grad, \ typename TTypes::Vec b_grad); \ - extern template struct BlockLSTMBprop; + extern template struct BlockLSTMBprop; #define DECLARE_GPU_SPECS(T) \ template <> \ @@ -1337,7 +1345,7 @@ namespace functor { extern template struct TensorCopy; \ extern template struct TensorAdd; \ \ - DECLARE_GPU_BPROP(T); + DECLARE_GPU_BPROP(T, ICFO); DECLARE_GPU_SPECS(Eigen::half); DECLARE_GPU_SPECS(float); @@ -1350,7 +1358,7 @@ DECLARE_GPU_SPECS(float); .Device(DEVICE_GPU) \ .HostMemory("seq_len_max") \ .TypeConstraint("T"), \ - BlockLSTMGradOp); + BlockLSTMGradOp); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); diff --git a/tensorflow/core/kernels/rnn/lstm_ops.h b/tensorflow/core/kernels/rnn/lstm_ops.h index fd069f6512a..834a9231433 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops.h +++ b/tensorflow/core/kernels/rnn/lstm_ops.h @@ -25,6 +25,16 @@ limitations under the License. namespace tensorflow { class OpKernelContext; +enum GateLayout { ICFO, IFCO }; + +constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) { + return (gate_layout == ICFO) ? cell_size : cell_size * 2; +} + +constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) { + return (gate_layout == ICFO) ? cell_size * 2 : cell_size; +} + namespace functor { template @@ -107,12 +117,14 @@ struct LSTMBlockCell { return {0, 0}; } - inline Eigen::array gates_c_offsets() const { - return {0, cell_size_}; + inline Eigen::array gates_c_offsets( + const GateLayout gate_layout) const { + return {0, gate_c_offset(gate_layout, cell_size_)}; } - inline Eigen::array gates_f_offsets() const { - return {0, cell_size_ * 2}; + inline Eigen::array gates_f_offsets( + const GateLayout gate_layout) const { + return {0, gate_f_offset(gate_layout, cell_size_)}; } inline Eigen::array gates_o_offsets() const { @@ -147,7 +159,7 @@ struct LSTMBlockCell { // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for // GPUDevice implementation. -template +template struct LSTMBlockCellFprop : public LSTMBlockCell { LSTMBlockCellFprop(const int batch_size, const int input_size, const int cell_size) @@ -172,7 +184,7 @@ struct LSTMBlockCellFprop : public LSTMBlockCell { // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for // GPUDevice implementation. -template +template struct LSTMBlockCellBprop : public LSTMBlockCell { LSTMBlockCellBprop(const int batch_size, const int input_size, const int cell_size) @@ -197,7 +209,7 @@ struct LSTMBlockCellBprop : public LSTMBlockCell { typename TTypes::Vec wcf_grad, typename TTypes::Vec wco_grad); }; -template +template struct BlockLSTMBprop : public LSTMBlockCell { BlockLSTMBprop(const int batch_size, const int input_size, const int cell_size) @@ -248,8 +260,8 @@ struct BlockLSTMBprop : public LSTMBlockCell { di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di; - dgates.slice(gates_c_offsets(), cell_extents()).device(d) = dci; - dgates.slice(gates_f_offsets(), cell_extents()).device(d) = df; + dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci; + dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df; dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_; cs_prev_grad.device(d) = dcs * f; diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc index 577b05791a1..9e872681f60 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc @@ -81,7 +81,7 @@ namespace { // Launch with blocks of (batch x 32) // // TODO(b/67600500): Try making 'use_peephole' a template parameter. -template +template __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev, const T* wci, const T* wcf, const T* wco, T* o, T* h, T* ci, T* cs, T* co, T* i, T* f, @@ -156,18 +156,19 @@ __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev, } i[cid] = i_local; - const T ci_local = - tanh_op(gates[1 * cell_size + gid] + b[1 * cell_size + act_id]); + const int c_offset = gate_c_offset(gate_layout, cell_size); + const int f_offset = gate_f_offset(gate_layout, cell_size); + + const T ci_local = tanh_op(gates[c_offset + gid] + b[c_offset + act_id]); ci[cid] = ci_local; T f_local; if (use_peephole) { - f_local = - sigmoid_op(gates[2 * cell_size + gid] + b[2 * cell_size + act_id] + - forget_bias_t + cs_prev[cid] * wcf[act_id]); + f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] + + forget_bias_t + cs_prev[cid] * wcf[act_id]); } else { - f_local = sigmoid_op(gates[2 * cell_size + gid] + - b[2 * cell_size + act_id] + forget_bias_t); + f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] + + forget_bias_t); } f[cid] = f_local; @@ -222,7 +223,7 @@ __global__ void concat_xh(T* xh, const T* x, const T* h_prev, } } -template +template void LSTMBlockCellFpropWithCUDA( OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, const float cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, @@ -267,20 +268,22 @@ void LSTMBlockCellFpropWithCUDA( if (use_peephole) { TF_CHECK_OK(GpuLaunchKernel( - lstm_gates, grid_dim_2d, block_dim_2d, 0, cu_stream, - gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), - wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), - i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size)); + lstm_gates, grid_dim_2d, block_dim_2d, 0, + cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(), + wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(), + co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size, + cell_size)); } else { TF_CHECK_OK(GpuLaunchKernel( - lstm_gates, grid_dim_2d, block_dim_2d, 0, cu_stream, - gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), - wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), - i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size)); + lstm_gates, grid_dim_2d, block_dim_2d, 0, + cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(), + wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(), + co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size, + cell_size)); } } -template +template __global__ void lstm_gates_bprop( const T* cs_prev, // [batch_size, cell_size] const T* h_prev, // [batch_size, cell_size] @@ -347,8 +350,8 @@ __global__ void lstm_gates_bprop( di[cid] = di_local; dgates[gid + 0 * cell_size] = di_local; - dgates[gid + 1 * cell_size] = dci_local; - dgates[gid + 2 * cell_size] = df_local; + dgates[gate_c_offset(gate_layout, cell_size)] = dci_local; + dgates[gate_f_offset(gate_layout, cell_size)] = df_local; dgates[gid + 3 * cell_size] = do_local; cs_prev_grad[cid] = dcs_local * f_local; @@ -357,7 +360,7 @@ __global__ void lstm_gates_bprop( } } -template +template void LSTMBlockCellBpropWithCUDA( OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, @@ -382,7 +385,7 @@ void LSTMBlockCellBpropWithCUDA( Eigen::divup(cell_size, static_cast(block_dim_2d.y))); TF_CHECK_OK(GpuLaunchKernel( - lstm_gates_bprop, grid_dim_2d, block_dim_2d, 0, cu_stream, + lstm_gates_bprop, grid_dim_2d, block_dim_2d, 0, cu_stream, cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(), wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(), co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(), @@ -403,54 +406,59 @@ void LSTMBlockCellBpropWithCUDA( } // namespace -#define DECLARE_GPU_FBPROP(T) \ - template <> \ - void LSTMBlockCellFprop::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \ - const float cell_clip, bool use_peephole, \ - typename TTypes::ConstMatrix x, \ - typename TTypes::ConstMatrix cs_prev, \ - typename TTypes::ConstMatrix h_prev, \ - typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ - typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ - typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ - typename TTypes::Matrix i, typename TTypes::Matrix cs, \ - typename TTypes::Matrix f, typename TTypes::Matrix o, \ - typename TTypes::Matrix ci, typename TTypes::Matrix co, \ - typename TTypes::Matrix gates, typename TTypes::Matrix h) { \ - LSTMBlockCellFpropWithCUDA(ctx, d, forget_bias, cell_clip, \ - use_peephole, x, cs_prev, h_prev, w, wci, \ - wcf, wco, b, xh, i, cs, f, o, ci, co, gates, \ - h, batch_size_, cell_size_, input_size_); \ - } \ - template <> \ - void LSTMBlockCellBprop::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \ - typename TTypes::ConstMatrix x, \ - typename TTypes::ConstMatrix cs_prev, \ - typename TTypes::ConstMatrix h_prev, \ - typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ - typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ - typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ - typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ - typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ - typename TTypes::ConstMatrix co, \ - typename TTypes::ConstMatrix cs_grad, \ - typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ - typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ - typename TTypes::Matrix df, typename TTypes::Matrix di, \ - typename TTypes::Matrix dgates, \ - typename TTypes::Matrix cs_prev_grad, \ - typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ - typename TTypes::Vec wco_grad) { \ - LSTMBlockCellBpropWithCUDA( \ - ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \ - cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \ - wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \ - } \ - template struct LSTMBlockCellFprop; \ - template struct LSTMBlockCellBprop; \ - template struct BlockLSTMBprop; +#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \ + template <> \ + void LSTMBlockCellFprop:: \ + operator()( \ + OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \ + const float cell_clip, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ + typename TTypes::Matrix i, typename TTypes::Matrix cs, \ + typename TTypes::Matrix f, typename TTypes::Matrix o, \ + typename TTypes::Matrix ci, typename TTypes::Matrix co, \ + typename TTypes::Matrix gates, typename TTypes::Matrix h) { \ + LSTMBlockCellFpropWithCUDA( \ + ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, h_prev, w, \ + wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h, batch_size_, \ + cell_size_, input_size_); \ + } \ + template <> \ + void LSTMBlockCellBprop:: \ + operator()( \ + OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ + typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ + typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ + typename TTypes::ConstMatrix co, \ + typename TTypes::ConstMatrix cs_grad, \ + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ + typename TTypes::Matrix df, typename TTypes::Matrix di, \ + typename TTypes::Matrix dgates, \ + typename TTypes::Matrix cs_prev_grad, \ + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ + typename TTypes::Vec wco_grad) { \ + LSTMBlockCellBpropWithCUDA( \ + ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \ + cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \ + wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \ + } \ + template struct LSTMBlockCellFprop; \ + template struct LSTMBlockCellBprop; \ + template struct BlockLSTMBprop; #define DECLARE_GPU_SPECS(T) \ template struct TensorZero; \ @@ -460,10 +468,10 @@ void LSTMBlockCellBpropWithCUDA( template struct TensorCopyToUnaligned; \ template struct TensorAdd; \ \ - DECLARE_GPU_FBPROP(T); + DECLARE_GPU_FBPROP(T, ICFO); -DECLARE_GPU_SPECS(float); DECLARE_GPU_SPECS(Eigen::half); +DECLARE_GPU_SPECS(float); #undef DECLARE_GPU_SPECS #undef DECLARE_GPU_FBPROP } // end namespace functor