BlockLSTM and LSTMBlockCell are now generic w.r.t. the gate layout

The default layout is still ICFO, but the V2 version of the op will use IFCO to match CuDNN-RNN.

PiperOrigin-RevId: 262538325
commit ed0eee6898
parent 626d98bf78
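For reference, a minimal standalone C++ sketch (not part of the commit) of what the two gate layouts mean for the fused [batch, 4 * cell_size] gates buffer. The enum and the two offset helpers mirror the gate_c_offset / gate_f_offset functions added to the header in this change; the main() usage is illustrative only.

// Sketch only: shows how GateLayout selects where the "c" and "f" gate
// blocks live inside the fused gates buffer. The "i" block is always at
// offset 0 and the "o" block at 3 * cell_size, as in the kernels below.
#include <cstdio>

enum GateLayout { ICFO, IFCO };  // i = input, c = cell, f = forget, o = output

// Column offset of the cell-input ("c") block for a given layout.
constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
  return (gate_layout == ICFO) ? cell_size : cell_size * 2;
}

// Column offset of the forget ("f") block for a given layout.
constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
  return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
}

int main() {
  constexpr int cell_size = 128;
  // ICFO (existing BlockLSTM layout): blocks ordered i, c, f, o.
  std::printf("ICFO: c at %d, f at %d\n",
              gate_c_offset(ICFO, cell_size), gate_f_offset(ICFO, cell_size));
  // IFCO (CuDNN-RNN layout, to be used by the V2 op): blocks ordered i, f, c, o.
  std::printf("IFCO: c at %d, f at %d\n",
              gate_c_offset(IFCO, cell_size), gate_f_offset(IFCO, cell_size));
}

Under ICFO the four blocks sit at offsets 0, cell_size, 2 * cell_size, 3 * cell_size; IFCO swaps the c and f blocks, which is the ordering CuDNN-RNN expects.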
@@ -41,7 +41,7 @@ typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellFpropWithEigen(
 const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
 const float forget_bias, const float cell_clip, bool use_peephole,
@@ -83,18 +83,21 @@ void LSTMBlockCellFpropWithEigen(
 
 // Cell input.
 ci.device(d) =
-gates.slice(cell.gates_c_offsets(), cell.cell_extents()).tanh();
+gates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
+.tanh();
 
 // Forget gate (w/ bias).
 if (use_peephole) {
 auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
-f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
-f.constant(T(forget_bias)) + f_peep)
-.sigmoid();
+f.device(d) =
+(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
+f.constant(T(forget_bias)) + f_peep)
+.sigmoid();
 } else {
-f.device(d) = (gates.slice(cell.gates_f_offsets(), cell.cell_extents()) +
-f.constant(T(forget_bias)))
-.sigmoid();
+f.device(d) =
+(gates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents()) +
+f.constant(T(forget_bias)))
+.sigmoid();
 }
 
 // cs = ci .* i + f .* cs_prev
@@ -123,7 +126,7 @@ void LSTMBlockCellFpropWithEigen(
 h.device(d) = o * co;
 }
 
-template <typename Device, typename T>
+template <typename Device, typename T, GateLayout gate_layout>
 void LSTMBlockCellBpropWithEigen(
 const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
 bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -164,8 +167,10 @@ void LSTMBlockCellBpropWithEigen(
 di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
 
 dgates.slice(cell.gates_i_offsets(), cell.cell_extents()).device(d) = di;
-dgates.slice(cell.gates_c_offsets(), cell.cell_extents()).device(d) = dci;
-dgates.slice(cell.gates_f_offsets(), cell.cell_extents()).device(d) = df;
+dgates.slice(cell.gates_c_offsets(gate_layout), cell.cell_extents())
+.device(d) = dci;
+dgates.slice(cell.gates_f_offsets(gate_layout), cell.cell_extents())
+.device(d) = df;
 dgates.slice(cell.gates_o_offsets(), cell.cell_extents()).device(d) = do_;
 
 cs_prev_grad.device(d) = dcs * f;
@@ -179,54 +184,58 @@ void LSTMBlockCellBpropWithEigen(
 }
 }
 
-#define DECLARE_CPU_FBPROP(T) \
+#define DECLARE_CPU_FBPROP(T, GATE_LAYOUT) \
 template <> \
-void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
-OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
-const float cell_clip, bool use_peephole, \
-typename TTypes<T>::ConstMatrix x, \
-typename TTypes<T>::ConstMatrix cs_prev, \
-typename TTypes<T>::ConstMatrix h_prev, \
-typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
-typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
-typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
-typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
-typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
-typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
-typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
-LSTMBlockCellFpropWithEigen<T>( \
-*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
-h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
-} \
-template <> \
-void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
-OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
-typename TTypes<T>::ConstMatrix x, \
-typename TTypes<T>::ConstMatrix cs_prev, \
-typename TTypes<T>::ConstMatrix h_prev, \
-typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
-typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
-typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
-typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
-typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
-typename TTypes<T>::ConstMatrix co, \
-typename TTypes<T>::ConstMatrix cs_grad, \
-typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
-typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
-typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
-typename TTypes<T>::Matrix dgates, \
-typename TTypes<T>::Matrix cs_prev_grad, \
-typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
-typename TTypes<T>::Vec wco_grad) { \
-LSTMBlockCellBpropWithEigen<CPUDevice, T>( \
-*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
-i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
-cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
-} \
-template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>; \
-template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;
+void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
+operator()( \
+OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
+const float cell_clip, bool use_peephole, \
+typename TTypes<T>::ConstMatrix x, \
+typename TTypes<T>::ConstMatrix cs_prev, \
+typename TTypes<T>::ConstMatrix h_prev, \
+typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
+typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
+typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
+typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
+typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
+typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
+typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
+LSTMBlockCellFpropWithEigen<T, GATE_LAYOUT>( \
+*this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
+h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h); \
+} \
+template <> \
+void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, GATE_LAYOUT>:: \
+operator()( \
+OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
+typename TTypes<T>::ConstMatrix x, \
+typename TTypes<T>::ConstMatrix cs_prev, \
+typename TTypes<T>::ConstMatrix h_prev, \
+typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
+typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
+typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
+typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
+typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
+typename TTypes<T>::ConstMatrix co, \
+typename TTypes<T>::ConstMatrix cs_grad, \
+typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
+typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
+typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
+typename TTypes<T>::Matrix dgates, \
+typename TTypes<T>::Matrix cs_prev_grad, \
+typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
+typename TTypes<T>::Vec wco_grad) { \
+LSTMBlockCellBpropWithEigen<CPUDevice, T, GATE_LAYOUT>( \
+*this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
+i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dgates, \
+cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
+} \
+template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */, \
+GATE_LAYOUT>; \
+template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, \
+GATE_LAYOUT>;
 
-#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T);
+#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);
 
 DECLARE_CPU_SPECS(Eigen::half);
 DECLARE_CPU_SPECS(float);
@@ -234,9 +243,9 @@ DECLARE_CPU_SPECS(float);
 #undef DECLARE_CPU_FBPROP
 
 #if GOOGLE_CUDA
-#define DECLARE_GPU_FBPROP(T) \
+#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
 template <> \
-void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
+void LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
 OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
 const float cell_clip, bool use_peephole, \
 typename TTypes<T>::ConstMatrix x, \
@@ -250,7 +259,7 @@ DECLARE_CPU_SPECS(float);
 typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
 typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h); \
 template <> \
-void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
+void LSTMBlockCellBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
 OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
 typename TTypes<T>::ConstMatrix x, \
 typename TTypes<T>::ConstMatrix cs_prev, \
@@ -270,21 +279,20 @@ DECLARE_CPU_SPECS(float);
 typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
 typename TTypes<T>::Vec wco_grad); \
 \
-extern template struct LSTMBlockCellBprop<GPUDevice, T, \
-true /* USE_CUBLAS */>; \
-extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
+extern template struct LSTMBlockCellBprop< \
+GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>; \
+extern template struct LSTMBlockCellFprop<GPUDevice, T, true, GATE_LAYOUT>;
 
-#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T);
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_FBPROP(T, ICFO);
 
-DECLARE_GPU_SPECS(Eigen::half);
 DECLARE_GPU_SPECS(float);
+DECLARE_GPU_SPECS(Eigen::half);
 #undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_FBPROP
+#undef DECLARE_GPU_FBROP
 #endif // GOOGLE_CUDA
 
 } // namespace functor
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class LSTMBlockCellOp : public OpKernel {
 public:
 explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -406,8 +414,8 @@ class LSTMBlockCellOp : public OpKernel {
 
 const Device& device = ctx->eigen_device<Device>();
 
-functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-cell_size)(
+functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
+batch_size, input_size, cell_size)(
 ctx, device, forget_bias_, cell_clip_, use_peephole_,
 x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
 h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
@@ -427,7 +435,7 @@ class LSTMBlockCellOp : public OpKernel {
 #define REGISTER_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-LSTMBlockCellOp<CPUDevice, T, false>);
+LSTMBlockCellOp<CPUDevice, T, false, ICFO>);
 
 REGISTER_KERNEL(Eigen::half);
 REGISTER_KERNEL(float);
@@ -437,14 +445,14 @@ REGISTER_KERNEL(float);
 #define REGISTER_GPU_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
-LSTMBlockCellOp<GPUDevice, T, true>);
+LSTMBlockCellOp<GPUDevice, T, true, ICFO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif // GOOGLE_CUDA
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class LSTMBlockCellGradOp : public OpKernel {
 public:
 explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -668,8 +676,8 @@ class LSTMBlockCellGradOp : public OpKernel {
 functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
 functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
 
-functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-cell_size)(
+functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS, gate_layout>(
+batch_size, input_size, cell_size)(
 ctx, device, use_peephole_, x_tensor->matrix<T>(),
 cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
 w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -691,7 +699,7 @@ class LSTMBlockCellGradOp : public OpKernel {
 #define REGISTER_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-LSTMBlockCellGradOp<CPUDevice, T, false>);
+LSTMBlockCellGradOp<CPUDevice, T, false, ICFO>);
 REGISTER_KERNEL(float);
 REGISTER_KERNEL(Eigen::half);
 #undef REGISTER_KERNEL
@@ -700,7 +708,7 @@ REGISTER_KERNEL(Eigen::half);
 #define REGISTER_GPU_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
-LSTMBlockCellGradOp<GPUDevice, T, true>);
+LSTMBlockCellGradOp<GPUDevice, T, true, ICFO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
@@ -815,7 +823,7 @@ class SliceHelper {
 
 } // namespace
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class BlockLSTMOp : public OpKernel {
 public:
 explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -972,8 +980,8 @@ class BlockLSTMOp : public OpKernel {
 Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
 Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
 
-functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-cell_size)(
+functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
+batch_size, input_size, cell_size)(
 ctx, device, forget_bias_, cell_clip_, use_peephole_,
 x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
 h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
@@ -1005,7 +1013,7 @@ class BlockLSTMOp : public OpKernel {
 #define REGISTER_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-BlockLSTMOp<CPUDevice, T, false>);
+BlockLSTMOp<CPUDevice, T, false, ICFO>);
 
 REGISTER_KERNEL(Eigen::half);
 REGISTER_KERNEL(float);
@@ -1036,14 +1044,14 @@ DECLARE_GPU_SPECS(float);
 .Device(DEVICE_GPU) \
 .HostMemory("seq_len_max") \
 .TypeConstraint<T>("T"), \
-BlockLSTMOp<GPUDevice, T, true>);
+BlockLSTMOp<GPUDevice, T, true, ICFO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
 #undef REGISTER_GPU_KERNEL
 #endif // GOOGLE_CUDA
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class BlockLSTMGradOp : public OpKernel {
 public:
 explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1246,8 +1254,8 @@ class BlockLSTMGradOp : public OpKernel {
 const Tensor& const_h_grad_tensor = h_grad_tensor;
 
 Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
-functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
-cell_size)(
+functor::BlockLSTMBprop<Device, T, USE_CUBLAS, gate_layout>(
+batch_size, input_size, cell_size)(
 ctx, device, use_peephole_, x_tensor.matrix<T>(),
 cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
 w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
@@ -1279,7 +1287,7 @@ class BlockLSTMGradOp : public OpKernel {
 #define REGISTER_KERNEL(T) \
 REGISTER_KERNEL_BUILDER( \
 Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-BlockLSTMGradOp<CPUDevice, T, false>);
+BlockLSTMGradOp<CPUDevice, T, false, ICFO>);
 
 REGISTER_KERNEL(Eigen::half);
 REGISTER_KERNEL(float);
@@ -1287,9 +1295,9 @@ REGISTER_KERNEL(float);
 
 #if GOOGLE_CUDA
 namespace functor {
-#define DECLARE_GPU_BPROP(T) \
+#define DECLARE_GPU_BPROP(T, GATE_LAYOUT) \
 template <> \
-void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
+void BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>::operator()( \
 OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
 typename TTypes<T>::ConstMatrix x, \
 typename TTypes<T>::ConstMatrix cs_prev, \
@@ -1311,7 +1319,7 @@ namespace functor {
 typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \
 typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, \
 typename TTypes<T>::Vec b_grad); \
-extern template struct BlockLSTMBprop<GPUDevice, T, true>;
+extern template struct BlockLSTMBprop<GPUDevice, T, true, GATE_LAYOUT>;
 
 #define DECLARE_GPU_SPECS(T) \
 template <> \
@@ -1337,7 +1345,7 @@ namespace functor {
 extern template struct TensorCopy<GPUDevice, T>; \
 extern template struct TensorAdd<GPUDevice, T>; \
 \
-DECLARE_GPU_BPROP(T);
+DECLARE_GPU_BPROP(T, ICFO);
 
 DECLARE_GPU_SPECS(Eigen::half);
 DECLARE_GPU_SPECS(float);
@@ -1350,7 +1358,7 @@ DECLARE_GPU_SPECS(float);
 .Device(DEVICE_GPU) \
 .HostMemory("seq_len_max") \
 .TypeConstraint<T>("T"), \
-BlockLSTMGradOp<GPUDevice, T, true>);
+BlockLSTMGradOp<GPUDevice, T, true, ICFO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
@@ -25,6 +25,16 @@ limitations under the License.
 namespace tensorflow {
 class OpKernelContext;
 
+enum GateLayout { ICFO, IFCO };
+
+constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) {
+return (gate_layout == ICFO) ? cell_size : cell_size * 2;
+}
+
+constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) {
+return (gate_layout == ICFO) ? cell_size * 2 : cell_size;
+}
+
 namespace functor {
 
 template <typename Device, typename T>
@@ -107,12 +117,14 @@ struct LSTMBlockCell {
 return {0, 0};
 }
 
-inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets() const {
-return {0, cell_size_};
+inline Eigen::array<Eigen::DenseIndex, 2> gates_c_offsets(
+const GateLayout gate_layout) const {
+return {0, gate_c_offset(gate_layout, cell_size_)};
 }
 
-inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets() const {
-return {0, cell_size_ * 2};
+inline Eigen::array<Eigen::DenseIndex, 2> gates_f_offsets(
+const GateLayout gate_layout) const {
+return {0, gate_f_offset(gate_layout, cell_size_)};
 }
 
 inline Eigen::array<Eigen::DenseIndex, 2> gates_o_offsets() const {
@@ -147,7 +159,7 @@ struct LSTMBlockCell {
 
 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
 // GPUDevice implementation.
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct LSTMBlockCellFprop : public LSTMBlockCell {
 LSTMBlockCellFprop(const int batch_size, const int input_size,
 const int cell_size)
@@ -172,7 +184,7 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {
 
 // See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for
 // GPUDevice implementation.
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct LSTMBlockCellBprop : public LSTMBlockCell {
 LSTMBlockCellBprop(const int batch_size, const int input_size,
 const int cell_size)
@@ -197,7 +209,7 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
 typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad);
 };
 
-template <typename Device, typename T, bool USE_CUBLAS>
+template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 struct BlockLSTMBprop : public LSTMBlockCell {
 BlockLSTMBprop(const int batch_size, const int input_size,
 const int cell_size)
@@ -248,8 +260,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
 di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
 
 dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di;
-dgates.slice(gates_c_offsets(), cell_extents()).device(d) = dci;
-dgates.slice(gates_f_offsets(), cell_extents()).device(d) = df;
+dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci;
+dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df;
 dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_;
 
 cs_prev_grad.device(d) = dcs * f;
@@ -81,7 +81,7 @@ namespace {
 // Launch with blocks of (batch x 32)
 //
 // TODO(b/67600500): Try making 'use_peephole' a template parameter.
-template <typename T, bool use_peephole>
+template <typename T, bool use_peephole, GateLayout gate_layout>
 __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
 const T* wci, const T* wcf, const T* wco, T* o, T* h,
 T* ci, T* cs, T* co, T* i, T* f,
@@ -156,18 +156,19 @@ __global__ void lstm_gates(const T* gates, const T* b, const T* cs_prev,
 }
 i[cid] = i_local;
 
-const T ci_local =
-tanh_op(gates[1 * cell_size + gid] + b[1 * cell_size + act_id]);
+const int c_offset = gate_c_offset(gate_layout, cell_size);
+const int f_offset = gate_f_offset(gate_layout, cell_size);
+
+const T ci_local = tanh_op(gates[c_offset + gid] + b[c_offset + act_id]);
 ci[cid] = ci_local;
 
 T f_local;
 if (use_peephole) {
-f_local =
-sigmoid_op(gates[2 * cell_size + gid] + b[2 * cell_size + act_id] +
-forget_bias_t + cs_prev[cid] * wcf[act_id]);
+f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
+forget_bias_t + cs_prev[cid] * wcf[act_id]);
 } else {
-f_local = sigmoid_op(gates[2 * cell_size + gid] +
-b[2 * cell_size + act_id] + forget_bias_t);
+f_local = sigmoid_op(gates[f_offset + gid] + b[f_offset + act_id] +
+forget_bias_t);
 }
 f[cid] = f_local;
 
@@ -222,7 +223,7 @@ __global__ void concat_xh(T* xh, const T* x, const T* h_prev,
 }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellFpropWithCUDA(
 OpKernelContext* ctx, const GPUDevice& d, const float forget_bias,
 const float cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
@@ -267,20 +268,22 @@ void LSTMBlockCellFpropWithCUDA(
 
 if (use_peephole) {
 TF_CHECK_OK(GpuLaunchKernel(
-lstm_gates<T, true>, grid_dim_2d, block_dim_2d, 0, cu_stream,
-gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
-wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
-i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
+lstm_gates<T, true, gate_layout>, grid_dim_2d, block_dim_2d, 0,
+cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
+wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
+co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
+cell_size));
 } else {
 TF_CHECK_OK(GpuLaunchKernel(
-lstm_gates<T, false>, grid_dim_2d, block_dim_2d, 0, cu_stream,
-gates.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(),
-wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(),
-i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size));
+lstm_gates<T, false, gate_layout>, grid_dim_2d, block_dim_2d, 0,
+cu_stream, gates.data(), b.data(), cs_prev.data(), wci.data(),
+wcf.data(), wco.data(), o.data(), h.data(), ci.data(), cs.data(),
+co.data(), i.data(), f.data(), forget_bias, cell_clip, batch_size,
+cell_size));
 }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 __global__ void lstm_gates_bprop(
 const T* cs_prev, // [batch_size, cell_size]
 const T* h_prev, // [batch_size, cell_size]
@@ -347,8 +350,8 @@ __global__ void lstm_gates_bprop(
 di[cid] = di_local;
 
 dgates[gid + 0 * cell_size] = di_local;
-dgates[gid + 1 * cell_size] = dci_local;
-dgates[gid + 2 * cell_size] = df_local;
+dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
+dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
 dgates[gid + 3 * cell_size] = do_local;
 
 cs_prev_grad[cid] = dcs_local * f_local;
@@ -357,7 +360,7 @@ __global__ void lstm_gates_bprop(
 }
 }
 
-template <typename T>
+template <typename T, GateLayout gate_layout>
 void LSTMBlockCellBpropWithCUDA(
 OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T>::ConstMatrix x,
 typename TTypes<T>::ConstMatrix cs_prev,
@@ -382,7 +385,7 @@ void LSTMBlockCellBpropWithCUDA(
 Eigen::divup(cell_size, static_cast<int>(block_dim_2d.y)));
 
 TF_CHECK_OK(GpuLaunchKernel(
-lstm_gates_bprop<T>, grid_dim_2d, block_dim_2d, 0, cu_stream,
+lstm_gates_bprop<T, gate_layout>, grid_dim_2d, block_dim_2d, 0, cu_stream,
 cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(),
 wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(),
 co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(),
@@ -403,54 +406,59 @@ void LSTMBlockCellBpropWithCUDA(
 
 } // namespace
 
-#define DECLARE_GPU_FBPROP(T) \
+#define DECLARE_GPU_FBPROP(T, GATE_LAYOUT) \
 template <> \
-void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
-OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
-const float cell_clip, bool use_peephole, \
-typename TTypes<T>::ConstMatrix x, \
-typename TTypes<T>::ConstMatrix cs_prev, \
-typename TTypes<T>::ConstMatrix h_prev, \
-typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
-typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
-typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
-typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
-typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
-typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
-typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
-LSTMBlockCellFpropWithCUDA<T>(ctx, d, forget_bias, cell_clip, \
-use_peephole, x, cs_prev, h_prev, w, wci, \
-wcf, wco, b, xh, i, cs, f, o, ci, co, gates, \
-h, batch_size_, cell_size_, input_size_); \
-} \
-template <> \
-void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>::operator()( \
-OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
-typename TTypes<T>::ConstMatrix x, \
-typename TTypes<T>::ConstMatrix cs_prev, \
-typename TTypes<T>::ConstMatrix h_prev, \
-typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
-typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
-typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
-typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
-typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
-typename TTypes<T>::ConstMatrix co, \
-typename TTypes<T>::ConstMatrix cs_grad, \
-typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
-typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
-typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
-typename TTypes<T>::Matrix dgates, \
-typename TTypes<T>::Matrix cs_prev_grad, \
-typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
-typename TTypes<T>::Vec wco_grad) { \
-LSTMBlockCellBpropWithCUDA<T>( \
-ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
-cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
-wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
-} \
-template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */>; \
-template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */>; \
-template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */>;
+void LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
+operator()( \
+OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
+const float cell_clip, bool use_peephole, \
+typename TTypes<T>::ConstMatrix x, \
+typename TTypes<T>::ConstMatrix cs_prev, \
+typename TTypes<T>::ConstMatrix h_prev, \
+typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
+typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
+typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
+typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
+typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
+typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
+typename TTypes<T>::Matrix gates, typename TTypes<T>::Matrix h) { \
+LSTMBlockCellFpropWithCUDA<T, GATE_LAYOUT>( \
+ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, h_prev, w, \
+wci, wcf, wco, b, xh, i, cs, f, o, ci, co, gates, h, batch_size_, \
+cell_size_, input_size_); \
+} \
+template <> \
+void LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, GATE_LAYOUT>:: \
+operator()( \
+OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
+typename TTypes<T>::ConstMatrix x, \
+typename TTypes<T>::ConstMatrix cs_prev, \
+typename TTypes<T>::ConstMatrix h_prev, \
+typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
+typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
+typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
+typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
+typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
+typename TTypes<T>::ConstMatrix co, \
+typename TTypes<T>::ConstMatrix cs_grad, \
+typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
+typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
+typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
+typename TTypes<T>::Matrix dgates, \
+typename TTypes<T>::Matrix cs_prev_grad, \
+typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
+typename TTypes<T>::Vec wco_grad) { \
+LSTMBlockCellBpropWithCUDA<T, GATE_LAYOUT>( \
+ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \
+cs_grad, h_grad, do_, dcs, dci, df, di, dgates, cs_prev_grad, \
+wci_grad, wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \
+} \
+template struct LSTMBlockCellFprop<GPUDevice, T, true /* USE_CUBLAS */, \
+GATE_LAYOUT>; \
+template struct LSTMBlockCellBprop<GPUDevice, T, true /* USE_CUBLAS */, \
+GATE_LAYOUT>; \
+template struct BlockLSTMBprop<GPUDevice, T, true /* USE_CUBLAS */, \
+GATE_LAYOUT>;
 
 #define DECLARE_GPU_SPECS(T) \
 template struct TensorZero<GPUDevice, T>; \
@@ -460,10 +468,10 @@ void LSTMBlockCellBpropWithCUDA(
 template struct TensorCopyToUnaligned<GPUDevice, T>; \
 template struct TensorAdd<GPUDevice, T>; \
 \
-DECLARE_GPU_FBPROP(T);
+DECLARE_GPU_FBPROP(T, ICFO);
 
-DECLARE_GPU_SPECS(float);
 DECLARE_GPU_SPECS(Eigen::half);
+DECLARE_GPU_SPECS(float);
 #undef DECLARE_GPU_SPECS
 #undef DECLARE_GPU_FBPROP
 } // end namespace functor