BlockLSTM and LSTMBlockCell are now generic with respect to the gate layout.

The default layout is still ICFO, but the V2 version of the op will use IFCO to match CuDNN-RNN.

PiperOrigin-RevId: 262538325
parent 626d98bf78
commit ed0eee6898
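For reference, the two layouts differ only in where the cell-input (c) and forget (f) blocks sit inside the packed [batch, 4 * cell_size] gates matrix; the input (i) block is always first and the output (o) block last. Below is a minimal standalone sketch mirroring the GateLayout enum and the gate_c_offset/gate_f_offset helpers added in this change (the cell_size value and main() are illustrative only, not part of the commit):

#include <cstdio>

// Mirrors the enum and constexpr helpers this change adds to the header.
enum GateLayout { ICFO, IFCO };

constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
  return (gate_layout == ICFO) ? cell_size : cell_size * 2;
}

constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
  return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
}

int main() {
  // Illustrative cell size; the real kernels take it from the input tensors.
  constexpr int cell_size = 4;
  // ICFO (current default): columns ordered [i | c | f | o].
  std::printf("ICFO: i=0 c=%d f=%d o=%d\n", gate_c_offset(ICFO, cell_size),
              gate_f_offset(ICFO, cell_size), 3 * cell_size);
  // IFCO (CuDNN-RNN order, planned for the V2 op): columns [i | f | c | o].
  std::printf("IFCO: i=0 f=%d c=%d o=%d\n", gate_f_offset(IFCO, cell_size),
              gate_c_offset(IFCO, cell_size), 3 * cell_size);
  return 0;
}

Because only these two offsets depend on the layout, the kernels in the diff below can stay layout-agnostic by threading a GateLayout template parameter through to the slicing and indexing code, while the existing registrations keep instantiating with ICFO.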
@@ -41,7 +41,7 @@ typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellFpropWithEigen(
const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
const float forget_bias, const float cell_clip, bool use_peephole,
@@ -83,18 +83,21 @@ void LSTMBlockCellFpropWithEigen(

// Cell input.
ci.device(d) =
gates.slice(cell.gates_c_offsets(), cell.cell_extents()).tanh();
gates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
.tanh();

// Forget gate (w/ bias).
if (use_peephole) {
auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
f.constant(T(forget_bias)) + f_peep)
.sigmoid();
f.device(d) =
(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
f.constant(T(forget_bias)) + f_peep)
.sigmoid();
} else {
f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
f.constant(T(forget_bias)))
.sigmoid();
f.device(d) =
(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
f.constant(T(forget_bias)))
.sigmoid();
}

// cs = ci .* i + f .* cs_prev
@@ -123,7 +126,7 @@ void LSTMBlockCellFpropWithEigen(
h.device(d) = o * co;
}

template <typename Device, typename T>
template <typename Device, typename T, GateLayout gate_layout>
void LSTMBlockCellBpropWithEigen(
const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -164,8 +167,10 @@ void LSTMBlockCellBpropWithEigen(
di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

dgates.slice(cell.gates_i_offsets(), cell.cell_extents()).device(d) = di;
dgates.slice(cell.gates_c_offsets(), cell.cell_extents()).device(d) = dci;
dgates.slice(cell.gates_f_offsets(), cell.cell_extents()).device(d) = df;
dgates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
.device(d) = dci;
dgates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents())
.device(d) = df;
dgates.slice(cell.gates_o_offsets(), cell.cell_extents()).device(d) = do_;

cs_prev_grad.device(d) = dcs * f;
@@ -179,54 +184,58 @@ void LSTMBlockCellBpropWithEigen(
}
}

#define DECLARE_CPU_FBPROP(T) \
template <> \
void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithEigen<T>( \
*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
} \
template <> \
void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithEigen<CPUDevice, T>( \
*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
} \
template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>; \
template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
#define DECLARE_CPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithEigen<T, GATE_LAYOUT>( \
*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
} \
template <> \
void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithEigen<CPUDevice, T, GATE_LAYOUT>( \
*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
} \
template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, \
GATE_LAYOUT>;

#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T);
#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);

DECLARE_CPU_SPECS(Eigen::half);
DECLARE_CPU_SPECS(float);
@@ -234,9 +243,9 @@ DECLARE_CPU_SPECS(float);
#undef DECLARE_CPU_FBPROP

#if GOOGLE_CUDA
#define DECLARE_GPU_FBPROP(T) \
#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
void LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
@@ -250,7 +259,7 @@ DECLARE_CPU_SPECS(float);
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h); \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
void LSTMBlockCellBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
@@ -270,21 +279,20 @@ DECLARE_CPU_SPECS(float);
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad); \
\
extern template struct LSTMBlockCellBprop<GPUDevice, T, \
true /* USE_CUBLAS */>; \
extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
extern template struct LSTMBlockCellBprop< \
GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>; \
extern template struct LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T);
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T, ICFO);

DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(Eigen::half);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_FBPROP
#undef DECLARE_GPU_FBROP
#endif // GOOGLE_CUDA

} // namespace functor

template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class LSTMBlockCellOp : public OpKernel {
public:
explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -406,8 +414,8 @@ class LSTMBlockCellOp : public OpKernel {

const Device& device = ctx->eigen_device<Device>();

functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
@@ -427,7 +435,7 @@ class LSTMBlockCellOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
LSTMBlockCellOp<CPUDevice, T, false>);
LSTMBlockCellOp<CPUDevice, T, false, ICFO>);

REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@@ -437,14 +445,14 @@ REGISTER_KERNEL(float);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
LSTMBlockCellOp<GPUDevice, T, true>);
LSTMBlockCellOp<GPUDevice, T, true, ICFO>);

REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class LSTMBlockCellGradOp : public OpKernel {
public:
explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -668,8 +676,8 @@ class LSTMBlockCellGradOp : public OpKernel {
functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());

functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, use_peephole_, x_tensor->matrix<T>(),
cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -691,7 +699,7 @@ class LSTMBlockCellGradOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
LSTMBlockCellGradOp<CPUDevice, T, false>);
LSTMBlockCellGradOp<CPUDevice, T, false, ICFO>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL
@@ -700,7 +708,7 @@ REGISTER_KERNEL(Eigen::half);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
LSTMBlockCellGradOp<GPUDevice, T, true>);
LSTMBlockCellGradOp<GPUDevice, T, true, ICFO>);

REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
@@ -815,7 +823,7 @@ class SliceHelper {

} // namespace

template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class BlockLSTMOp : public OpKernel {
public:
explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -972,8 +980,8 @@ class BlockLSTMOp : public OpKernel {
Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");

functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
@@ -1005,7 +1013,7 @@ class BlockLSTMOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
BlockLSTMOp<CPUDevice, T, false>);
BlockLSTMOp<CPUDevice, T, false, ICFO>);

REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@@ -1036,14 +1044,14 @@ DECLARE_GPU_SPECS(float);
.Device(DEVICE_GPU) \
.HostMemory("seq_len_max") \
.TypeConstraint<T>("T"), \
BlockLSTMOp<GPUDevice, T, true>);
BlockLSTMOp<GPUDevice, T, true, ICFO>);

REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
class BlockLSTMGradOp : public OpKernel {
public:
explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1246,8 +1254,8 @@ class BlockLSTMGradOp : public OpKernel {
const Tensor& const_h_grad_tensor = h_grad_tensor;

Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
functor::BlockLSTMBprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, use_peephole_, x_tensor.matrix<T>(),
cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -1279,7 +1287,7 @@ class BlockLSTMGradOp : public OpKernel {
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
BlockLSTMGradOp<CPUDevice, T, false>);
BlockLSTMGradOp<CPUDevice, T, false, ICFO>);

REGISTER_KERNEL(Eigen::half);
REGISTER_KERNEL(float);
@@ -1287,9 +1295,9 @@ REGISTER_KERNEL(float);

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_BPROP(T) \
#define DECLARE_GPU_BPROP(T, GATE_LAYOUT) \
template <> \
void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
void BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
@@ -1311,7 +1319,7 @@ namespace functor {
typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \
typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, \
typename TTypes<T>::Vec b_grad); \
extern template struct BlockLSTMBprop<GPUDevice, T, true>;
extern template struct BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>;

#define DECLARE_GPU_SPECS(T) \
template <> \
@@ -1337,7 +1345,7 @@ namespace functor {
extern template struct TensorCopy<GPUDevice, T>; \
extern template struct TensorAdd<GPUDevice, T>; \
\
DECLARE_GPU_BPROP(T);
DECLARE_GPU_BPROP(T, ICFO);

DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
@@ -1350,7 +1358,7 @@ DECLARE_GPU_SPECS(float);
.Device(DEVICE_GPU) \
.HostMemory("seq_len_max") \
.TypeConstraint<T>("T"), \
BlockLSTMGradOp<GPUDevice, T, true>);
BlockLSTMGradOp<GPUDevice, T, true, ICFO>);

REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);

@@ -25,6 +25,16 @@ limitations under the License.
namespace tensorflow {
class OpKernelContext;

enum GateLayout { ICFO, IFCO };

constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
return (gate_layout == ICFO) ? cell_size : cell_size * 2;
}

constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
}

namespace functor {

template <typename Device, typename T>
@@ -107,12 +117,14 @@ struct LSTMBlockCell {
return {0, 0};
}

inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets() const {
return {0, cell_size_};
inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets(
const GateLayout gate_layout) const {
return {0, gate_c_offset(gate_layout, cell_size_)};
}

inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets() const {
return {0, cell_size_ * 2};
inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets(
const GateLayout gate_layout) const {
return {0, gate_f_offset(gate_layout, cell_size_)};
}

inline Eigen::array<Eigen::DenseIndex, 2> gates_o_offsets() const {
@@ -147,7 +159,7 @@ struct LSTMBlockCell {

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct LSTMBlockCellFprop : public LSTMBlockCell {
LSTMBlockCellFprop(const int batch_size, const int input_size,
const int cell_size)
@@ -172,7 +184,7 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {

// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
// GPUDevice implementation.
template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct LSTMBlockCellBprop : public LSTMBlockCell {
LSTMBlockCellBprop(const int batch_size, const int input_size,
const int cell_size)
@@ -197,7 +209,7 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad);
};

template <typename Device, typename T, bool USE_CUBLAS>
template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
struct BlockLSTMBprop : public LSTMBlockCell {
BlockLSTMBprop(const int batch_size, const int input_size,
const int cell_size)
@@ -248,8 +260,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di;
dgates.slice(gates_c_offsets(), cell_extents()).device(d) = dci;
dgates.slice(gates_f_offsets(), cell_extents()).device(d) = df;
dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci;
dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df;
dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_;

cs_prev_grad.device(d) = dcs * f;

@@ -81,7 +81,7 @@ namespace {
// Launch with blocks of (batch x 32)
//
// TODO(b/67600500): Try making 'use_peephole' a template parameter.
template <typename T, bool use_peephole>
template <typename T, bool use_peephole, GateLayout gate_layout>
__global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
const T* wci, const T* wcf, const T* wco, T* o, T* h,
T* ci, T* cs, T* co, T* i, T* f,
@@ -156,18 +156,19 @@ __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
}
i[cid] = i_local;

const T ci_local =
tanh_op(gates[1 * cell_size + gid] + b[1 * cell_size + act_id]);
const int c_offset = gate_c_offset(gate_layout, cell_size);
const int f_offset = gate_f_offset(gate_layout, cell_size);

const T ci_local = tanh_op(gates[c_offset + gid] + b[c_offset + act_id]);
ci[cid] = ci_local;

T f_local;
if (use_peephole) {
f_local =
sigmoid_op(gates[2 * cell_size + gid] + b[2 * cell_size + act_id] +
forget_bias_t + cs_prev[cid] * wcf[act_id]);
f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
forget_bias_t + cs_prev[cid] * wcf[act_id]);
} else {
f_local = sigmoid_op(gates[2 * cell_size + gid] +
b[2 * cell_size + act_id] + forget_bias_t);
f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
forget_bias_t);
}
f[cid] = f_local;

@@ -222,7 +223,7 @@ __global__ void concat_xh(T* xh, const T* x, const T* h_prev,
}
}

template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellFpropWithCUDA(
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -267,20 +268,22 @@ void LSTMBlockCellFpropWithCUDA(

if (use_peephole) {
TF_CHECK_OK(GpuLaunchKernel(
lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
lstm_gates<T, true, gate_layout>, grid_dim_2d, block_dim_2d, 0,
cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
cell_size));
} else {
TF_CHECK_OK(GpuLaunchKernel(
lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
lstm_gates<T, false, gate_layout>, grid_dim_2d, block_dim_2d, 0,
cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
cell_size));
}
}

template <typename T>
template <typename T, GateLayout gate_layout>
__global__ void lstm_gates_bprop(
const T* cs_prev, // [batch_size, cell_size]
const T* h_prev, // [batch_size, cell_size]
@@ -347,8 +350,8 @@ __global__ void lstm_gates_bprop(
di[cid] = di_local;

dgates[gid + 0 * cell_size] = di_local;
dgates[gid + 1 * cell_size] = dci_local;
dgates[gid + 2 * cell_size] = df_local;
dgates[gid + gate_c_offset(gate_layout, cell_size)] = dci_local;
dgates[gid + gate_f_offset(gate_layout, cell_size)] = df_local;
dgates[gid + 3 * cell_size] = do_local;

cs_prev_grad[cid] = dcs_local * f_local;
@@ -357,7 +360,7 @@ __global__ void lstm_gates_bprop(
}
}

template <typename T>
template <typename T, GateLayout gate_layout>
void LSTMBlockCellBpropWithCUDA(
OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
@@ -382,7 +385,7 @@ void LSTMBlockCellBpropWithCUDA(
Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));

TF_CHECK_OK(GpuLaunchKernel(
lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
lstm_gates_bprop<T, gate_layout>, grid_dim_2d, block_dim_2d, 0, cu_stream,
cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
@@ -403,54 +406,59 @@ void LSTMBlockCellBpropWithCUDA(

} // namespace

#define DECLARE_GPU_FBPROP(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip, \
use_peephole, x, cs_prev, h_prev, w, wci, \
wcf, wco, b, xh, i, cs, f, o, ci, co, gates, \
h, batch_size_, cell_size_, input_size_); \
} \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithCUDA<T>( \
ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
} \
template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>; \
template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>; \
template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;
#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
const float cell_clip, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
LSTMBlockCellFpropWithCUDA<T, GATE_LAYOUT>( \
ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, h_prev, w, \
wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h, batch_size_, \
cell_size_, input_size_); \
} \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
operator()( \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
typename TTypes<T>::ConstMatrix co, \
typename TTypes<T>::ConstMatrix cs_grad, \
typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
typename TTypes<T>::Matrix dgates, \
typename TTypes<T>::Matrix cs_prev_grad, \
typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
typename TTypes<T>::Vec wco_grad) { \
LSTMBlockCellBpropWithCUDA<T, GATE_LAYOUT>( \
ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
} \
template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>; \
template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */, \
GATE_LAYOUT>;

#define DECLARE_GPU_SPECS(T) \
template struct TensorZero<GPUDevice, T>; \
@@ -460,10 +468,10 @@ void LSTMBlockCellBpropWithCUDA(
template struct TensorCopyToUnaligned<GPUDevice, T>; \
template struct TensorAdd<GPUDevice, T>; \
\
DECLARE_GPU_FBPROP(T);
DECLARE_GPU_FBPROP(T, ICFO);

DECLARE_GPU_SPECS(float);
DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(float);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_FBPROP
} // end namespace functor