BlockLSTM and LSTMBlockCell are now generic wrt to the gate layout

The default layout is still ICFO, but V2 version of the op will use IFCO
to match CuDNN-RNN.

PiperOrigin-RevId: 262538325
This commit is contained in:
Sergei Lebedev 2019-08-09 04:48:29 -07:00 committed by TensorFlower Gardener
parent 626d98bf78
commit ed0eee6898
3 changed files with 201 additions and 173 deletions

View File

@ -41,7 +41,7 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellFpropWithEigen(
const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
const float forget_bias, const float cell_clip, bool use_peephole,
@ -83,18 +83,21 @@ void LSTMBlockCellFpropWithEigen(
// Cell input.
ci.device(d) =
gates.slice(cell.gates_c_offsets(), cell.cell_extents()).tanh();
gates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
.tanh();
// Forget gate (w/ bias).
if (use_peephole) {
auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
f.constant(T(forget_bias)) + f_peep)
.sigmoid();
f.device(d) =
(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
f.constant(T(forget_bias)) + f_peep)
.sigmoid();
} else {
f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
f.constant(T(forget_bias)))
.sigmoid();
f.device(d) =
(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
f.constant(T(forget_bias)))
.sigmoid();
}
// cs = ci .* i + f .* cs_prev
@ -123,7 +126,7 @@ void LSTMBlockCellFpropWithEigen(
h.device(d) = o * co;
}
template <typename Device, typename T>
template <typename Device, typename T, GateLayout gate_layout>
void LSTMBlockCellBpropWithEigen(
const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
bool use_peephole, typename TTypes<T>::ConstMatrix x,
@ -164,8 +167,10 @@ void LSTMBlockCellBpropWithEigen(
di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
dgates.slice(cell.gates_i_offsets(), cell.cell_extents()).device(d) = di;
dgates.slice(cell.gates_c_offsets(), cell.cell_extents()).device(d) = dci;
dgates.slice(cell.gates_f_offsets(), cell.cell_extents()).device(d) = df;
dgates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
.device(d) = dci;
dgates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents())
.device(d) = df;
dgates.slice(cell.gates_o_offsets(), cell.cell_extents()).device(d) = do_;
cs_prev_grad.device(d) = dcs * f;
@ -179,54 +184,58 @@ void LSTMBlockCellBpropWithEigen(
}
}
#define DECLARE_CPU_FBPROP(T) \
template <> \
void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithEigen<T>( \
*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
} \
template <> \
void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithEigen<CPUDevice, T>( \
*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
} \
template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>; \
template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
#define DECLARE_CPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithEigen<T, GATE_LAYOUT>( \
*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
} \
template <> \
void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithEigen<CPUDevice, T, GATE_LAYOUT>( \
*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
} \
template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, \
GATE_LAYOUT>;
#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T);
#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);
DECLARE_CPU_SPECS(Eigen::half);
DECLARE_CPU_SPECS(float);
@ -234,9 +243,9 @@ DECLARE_CPU_SPECS(float);
#undef DECLARE_CPU_FBPROP
#if GOOGLE_CUDA
#define DECLARE_GPU_FBPROP(T) \
#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
void LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
@ -250,7 +259,7 @@ DECLARE_CPU_SPECS(float);
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h); \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
void LSTMBlockCellBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
@ -270,21 +279,20 @@ DECLARE_CPU_SPECS(float);
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad); \
\
extern template struct LSTMBlockCellBprop<GPUDevice, T, \
true /* USE_CUBLAS */>; \
extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
extern template struct LSTMBlockCellBprop< \
GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>; \
extern template struct LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T);
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T, ICFO);
DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(Eigen::half);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_FBPROP
#undef DECLARE_GPU_FBROP
#endif // GOOGLE_CUDA
} // namespace functor
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class LSTMBlockCellOp : public OpKernel {
public:
explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -406,8 +414,8 @@ class LSTMBlockCellOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
@ -427,7 +435,7 @@ class LSTMBlockCellOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
LSTMBlockCellOp<CPUDevice, T, false>);
LSTMBlockCellOp<CPUDevice, T, false, ICFO>);
REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@ -437,14 +445,14 @@ REGISTER_KERNEL(float);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
LSTMBlockCellOp<GPUDevice, T, true>);
LSTMBlockCellOp<GPUDevice, T, true, ICFO>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class LSTMBlockCellGradOp : public OpKernel {
public:
explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -668,8 +676,8 @@ class LSTMBlockCellGradOp : public OpKernel {
functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, use_peephole_, x_tensor->matrix<T>(),
cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@ -691,7 +699,7 @@ class LSTMBlockCellGradOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
LSTMBlockCellGradOp<CPUDevice, T, false>);
LSTMBlockCellGradOp<CPUDevice, T, false, ICFO>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL
@ -700,7 +708,7 @@ REGISTER_KERNEL(Eigen::half);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
LSTMBlockCellGradOp<GPUDevice, T, true>);
LSTMBlockCellGradOp<GPUDevice, T, true, ICFO>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
@ -815,7 +823,7 @@ class SliceHelper {
} // namespace
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class BlockLSTMOp : public OpKernel {
public:
explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -972,8 +980,8 @@ class BlockLSTMOp : public OpKernel {
Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
@ -1005,7 +1013,7 @@ class BlockLSTMOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
BlockLSTMOp<CPUDevice, T, false>);
BlockLSTMOp<CPUDevice, T, false, ICFO>);
REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@ -1036,14 +1044,14 @@ DECLARE_GPU_SPECS(float);
.Device(DEVICE_GPU) \
.HostMemory("seq_len_max") \
.TypeConstraint<T>("T"), \
BlockLSTMOp<GPUDevice, T, true>);
BlockLSTMOp<GPUDevice, T, true, ICFO>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class BlockLSTMGradOp : public OpKernel {
public:
explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -1246,8 +1254,8 @@ class BlockLSTMGradOp : public OpKernel {
const Tensor& const_h_grad_tensor = h_grad_tensor;
Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::BlockLSTMBprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, use_peephole_, x_tensor.matrix<T>(),
cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@ -1279,7 +1287,7 @@ class BlockLSTMGradOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
BlockLSTMGradOp<CPUDevice, T, false>);
BlockLSTMGradOp<CPUDevice, T, false, ICFO>);
REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@ -1287,9 +1295,9 @@ REGISTER_KERNEL(float);
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_BPROP(T) \
#define DECLARE_GPU_BPROP(T, GATE_LAYOUT) \
template <> \
void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
void BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
@ -1311,7 +1319,7 @@ namespace functor {
typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \
typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, \
typename TTypes<T>::Vec b_grad); \
extern template struct BlockLSTMBprop<GPUDevice, T, true>;
extern template struct BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>;
#define DECLARE_GPU_SPECS(T) \
template <> \
@ -1337,7 +1345,7 @@ namespace functor {
extern template struct TensorCopy<GPUDevice, T>; \
extern template struct TensorAdd<GPUDevice, T>; \
\
DECLARE_GPU_BPROP(T);
DECLARE_GPU_BPROP(T, ICFO);
DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
@ -1350,7 +1358,7 @@ DECLARE_GPU_SPECS(float);
.Device(DEVICE_GPU) \
.HostMemory("seq_len_max") \
.TypeConstraint<T>("T"), \
BlockLSTMGradOp<GPUDevice, T, true>);
BlockLSTMGradOp<GPUDevice, T, true, ICFO>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);

View File

@ -25,6 +25,16 @@ limitations under the License.
namespace tensorflow {
class OpKernelContext;
enum GateLayout { ICFO, IFCO };
constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
return (gate_layout == ICFO) ? cell_size : cell_size * 2;
}
constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
}
namespace functor {
template <typename Device, typename T>
@ -107,12 +117,14 @@ struct LSTMBlockCell {
return {0, 0};
}
inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets() const {
return {0, cell_size_};
inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets(
const GateLayout gate_layout) const {
return {0, gate_c_offset(gate_layout, cell_size_)};
}
inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets() const {
return {0, cell_size_ * 2};
inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets(
const GateLayout gate_layout) const {
return {0, gate_f_offset(gate_layout, cell_size_)};
}
inline Eigen::array<Eigen::DenseIndex, 2> gates_o_offsets() const {
@ -147,7 +159,7 @@ struct LSTMBlockCell {
// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct LSTMBlockCellFprop : public LSTMBlockCell {
LSTMBlockCellFprop(const int batch_size, const int input_size,
const int cell_size)
@ -172,7 +184,7 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {
// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct LSTMBlockCellBprop : public LSTMBlockCell {
LSTMBlockCellBprop(const int batch_size, const int input_size,
const int cell_size)
@ -197,7 +209,7 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad);
};
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct BlockLSTMBprop : public LSTMBlockCell {
BlockLSTMBprop(const int batch_size, const int input_size,
const int cell_size)
@ -248,8 +260,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di;
dgates.slice(gates_c_offsets(), cell_extents()).device(d) = dci;
dgates.slice(gates_f_offsets(), cell_extents()).device(d) = df;
dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci;
dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df;
dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_;
cs_prev_grad.device(d) = dcs * f;

View File

@ -81,7 +81,7 @@ namespace {
// Launch with blocks of (batch x 32)
//
// TODO(b/67600500): Try making 'use_peephole' a template parameter.
template <typename T, bool use_peephole>
template <typename T, bool use_peephole, GateLayout gate_layout>
__global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
const T* wci, const T* wcf, const T* wco, T* o, T* h,
T* ci, T* cs, T* co, T* i, T* f,
@ -156,18 +156,19 @@ __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
}
i[cid] = i_local;
const T ci_local =
tanh_op(gates[1 * cell_size + gid] + b[1 * cell_size + act_id]);
const int c_offset = gate_c_offset(gate_layout, cell_size);
const int f_offset = gate_f_offset(gate_layout, cell_size);
const T ci_local = tanh_op(gates[c_offset + gid] + b[c_offset + act_id]);
ci[cid] = ci_local;
T f_local;
if (use_peephole) {
f_local =
sigmoid_op(gates[2 * cell_size + gid] + b[2 * cell_size + act_id] +
forget_bias_t + cs_prev[cid] * wcf[act_id]);
f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
forget_bias_t + cs_prev[cid] * wcf[act_id]);
} else {
f_local = sigmoid_op(gates[2 * cell_size + gid] +
b[2 * cell_size + act_id] + forget_bias_t);
f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
forget_bias_t);
}
f[cid] = f_local;
@ -222,7 +223,7 @@ __global__ void concat_xh(T* xh, const T* x, const T* h_prev,
}
}
template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellFpropWithCUDA(
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
@ -267,20 +268,22 @@ void LSTMBlockCellFpropWithCUDA(
if (use_peephole) {
TF_CHECK_OK(GpuLaunchKernel(
lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
lstm_gates<T, true, gate_layout>, grid_dim_2d, block_dim_2d, 0,
cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
cell_size));
} else {
TF_CHECK_OK(GpuLaunchKernel(
lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
lstm_gates<T, false, gate_layout>, grid_dim_2d, block_dim_2d, 0,
cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
cell_size));
}
}
template <typename T>
template <typename T, GateLayout gate_layout>
__global__ void lstm_gates_bprop(
const T* cs_prev, // [batch_size, cell_size]
const T* h_prev, // [batch_size, cell_size]
@ -347,8 +350,8 @@ __global__ void lstm_gates_bprop(
di[cid] = di_local;
dgates[gid + 0 * cell_size] = di_local;
dgates[gid + 1 * cell_size] = dci_local;
dgates[gid + 2 * cell_size] = df_local;
dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
dgates[gid + 3 * cell_size] = do_local;
cs_prev_grad[cid] = dcs_local * f_local;
@ -357,7 +360,7 @@ __global__ void lstm_gates_bprop(
}
}
template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellBpropWithCUDA(
OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
@ -382,7 +385,7 @@ void LSTMBlockCellBpropWithCUDA(
Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));
TF_CHECK_OK(GpuLaunchKernel(
lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
lstm_gates_bprop<T, gate_layout>, grid_dim_2d, block_dim_2d, 0, cu_stream,
cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
@ -403,54 +406,59 @@ void LSTMBlockCellBpropWithCUDA(
} // namespace
#define DECLARE_GPU_FBPROP(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip, \
use_peephole, x, cs_prev, h_prev, w, wci, \
wcf, wco, b, xh, i, cs, f, o, ci, co, gates, \
h, batch_size_, cell_size_, input_size_); \
} \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithCUDA<T>( \
ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
} \
template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>; \
template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>; \
template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;
#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithCUDA<T, GATE_LAYOUT>( \
ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, h_prev, w, \
wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h, batch_size_, \
cell_size_, input_size_); \
} \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithCUDA<T, GATE_LAYOUT>( \
ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
} \
template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>;
#define DECLARE_GPU_SPECS(T) \
template struct TensorZero<GPUDevice, T>; \
@ -460,10 +468,10 @@ void LSTMBlockCellBpropWithCUDA(
template struct TensorCopyToUnaligned<GPUDevice, T>; \
template struct TensorAdd<GPUDevice, T>; \
\
DECLARE_GPU_FBPROP(T);
DECLARE_GPU_FBPROP(T, ICFO);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_FBPROP
} // end namespace functor