Removing stream to clean up code.

Change: 137403432
This commit is contained in:
A. Unique TensorFlower 2016-10-27 08:27:41 -08:00 committed by TensorFlower Gardener
parent 4ebf18a1e0
commit ef7ac603d0
6 changed files with 75 additions and 140 deletions

View File

@ -37,7 +37,6 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
namespace functor {
template <typename T>
void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
perftools::gputools::Stream* stream,
bool transa, bool transb, uint64 m,
uint64 n, uint64 k, T alpha, const T* a,
int lda, const T* b, int ldb, T beta, T* c,
@ -52,7 +51,8 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
auto c_ptr = AsDeviceMemory(c);
bool blas_launch_status =
stream
ctx->op_device_context()
->stream()
->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
lda, b_ptr, ldb, beta, &c_ptr, ldc)
.ok();

View File

@ -21,22 +21,15 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
namespace functor {
template <typename T>
struct TensorCuBlasGemm {
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
bool transa, bool transb, uint64 m, uint64 n, uint64 k,
T alpha, const T* a, int lda, const T* b, int ldb, T beta,
T* c, int ldc);
void operator()(OpKernelContext* ctx, bool transa, bool transb, uint64 m,
uint64 n, uint64 k, T alpha, const T* a, int lda, const T* b,
int ldb, T beta, T* c, int ldc);
};
template <typename Device, typename T, bool USE_CUBLAS>
@ -44,16 +37,15 @@ struct TensorBlasGemm;
template <typename Device, typename T>
struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool transa, bool transb, T alpha,
typename TTypes<T>::ConstMatrix a,
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
typename TTypes<T>::ConstMatrix b, T beta,
typename TTypes<T>::Matrix c) {
int64 m = c.dimensions()[0];
int64 n = c.dimensions()[1];
int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
TensorCuBlasGemm<T>()(ctx, transb, transa, n, m, k, alpha, b.data(),
transb ? k : n, a.data(), transa ? m : k, beta,
c.data(), n);
}
@ -61,9 +53,8 @@ struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
template <typename Device, typename T>
struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool transa, bool transb, T alpha,
typename TTypes<T>::ConstMatrix a,
static void compute(OpKernelContext* ctx, const Device& d, bool transa,
bool transb, T alpha, typename TTypes<T>::ConstMatrix a,
typename TTypes<T>::ConstMatrix b, T beta,
typename TTypes<T>::Matrix c) {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;

View File

@ -15,10 +15,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/gru_ops.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -151,14 +147,9 @@ class GRUCellBlockOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(),
r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(),
@ -362,14 +353,10 @@ class GRUBlockCellGradOp : public OpKernel {
&d_x_component_2_h_prevr));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(),
b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(),
u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(),
@ -400,8 +387,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GRUBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w_ru, \
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
@ -430,9 +417,9 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GRUBlockCellBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix h, typename TTypes<T>::ConstMatrix w_ru, \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h, \
typename TTypes<T>::ConstMatrix w_ru, \
typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \
typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r, \
typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c, \

View File

@ -21,12 +21,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
@ -77,18 +71,15 @@ struct GRUBlockCellFprop : public GRUCell {
const int cell_size)
: GRUCell(batch_size, input_size, cell_size) {}
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru,
typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru,
typename TTypes<T>::ConstVec b_c,
typename TTypes<T>::Matrix r_u_bar,
typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u,
typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h,
typename TTypes<T>::Matrix x_h_prev,
typename TTypes<T>::Matrix x_h_prevr) {
void operator()(
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
typename TTypes<T>::Matrix x_h_prevr) {
// Concat x_h_prev = [x, h_prev].
x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;
@ -96,9 +87,8 @@ struct GRUBlockCellFprop : public GRUCell {
// r_u_bar = x_h_prev * w_ru + b_ru
typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
x_h_prev.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, false,
T(1), const_x_h_prev, w_ru,
T(0), r_u_bar);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar);
// Creating a bias matrix for adding by broadcasting 'b_ru'
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
@ -117,7 +107,7 @@ struct GRUBlockCellFprop : public GRUCell {
typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
x_h_prevr.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c);
Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
@ -135,8 +125,7 @@ struct GRUBlockCellBprop : public GRUCell {
: GRUCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
@ -159,9 +148,9 @@ struct GRUBlockCellBprop : public GRUCell {
// [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T
typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
d_c_bar.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, stream, d, false, true,
T(1), const_d_c_bar, w_c,
T(0), d_x_comp2_and_h_prevr);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, true, T(1),
const_d_c_bar, w_c, T(0),
d_x_comp2_and_h_prevr);
d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);
@ -175,7 +164,7 @@ struct GRUBlockCellBprop : public GRUCell {
typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0),
d_x_comp1_and_h_prev_comp1);
// d_x = d_x_comp1 + d_x_comp2

View File

@ -34,10 +34,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@ -164,14 +160,10 @@ class LSTMBlockCellOp : public OpKernel {
&icfo_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
@ -196,22 +188,21 @@ REGISTER_KERNEL(float);
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, const T forget_bias, const T cell_clip, \
bool use_peephole, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
\
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, \
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
\
extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;
DECLARE_GPU_SPEC(float);
@ -445,10 +436,6 @@ class LSTMBlockCellGradOp : public OpKernel {
&di_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<float>());
functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<float>());
@ -456,7 +443,7 @@ class LSTMBlockCellGradOp : public OpKernel {
functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, use_peephole_, x_tensor->matrix<T>(),
ctx, device, use_peephole_, x_tensor->matrix<T>(),
cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
@ -486,8 +473,7 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, bool use_peephole, \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \
@ -769,10 +755,6 @@ class BlockLSTMOp : public OpKernel {
&icfo_tensor));
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
SliceHelper<Device, T> slicer(ctx);
@ -794,7 +776,7 @@ class BlockLSTMOp : public OpKernel {
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
ctx, device, forget_bias_, cell_clip_, use_peephole_,
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
@ -1020,10 +1002,6 @@ class BlockLSTMGradOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
perftools::gputools::Stream* stream =
std::is_same<Device, GPUDevice>::value
? ctx->op_device_context()->stream()
: nullptr;
functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<float>());
functor::TensorZero<Device, T>()(device,
@ -1073,7 +1051,7 @@ class BlockLSTMGradOp : public OpKernel {
Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, use_peephole_, x_tensor.matrix<T>(),
ctx, device, use_peephole_, x_tensor.matrix<T>(),
cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
@ -1134,8 +1112,7 @@ namespace functor {
\
template <> \
void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
OpKernelContext* ctx, perftools::gputools::Stream* stream, \
const GPUDevice& d, bool use_peephole, \
OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
typename TTypes<T>::ConstMatrix x, \
typename TTypes<T>::ConstMatrix cs_prev, \
typename TTypes<T>::ConstMatrix h_prev, \

View File

@ -22,12 +22,6 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
namespace perftools {
namespace gputools {
class Stream;
} // end namespace gputools
} // end namespace perftools
namespace tensorflow {
class OpKernelContext;
@ -153,29 +147,26 @@ struct LSTMBlockCellFprop : public LSTMBlockCell {
const int cell_size)
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, const T forget_bias, const T cell_clip,
bool use_peephole, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev,
typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci,
typename TTypes<T>::ConstVec wcf,
typename TTypes<T>::ConstVec wco,
typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
typename TTypes<T>::Matrix icfo,
typename TTypes<T>::Matrix h) {
void operator()(
OpKernelContext* ctx, const Device& d, const T forget_bias,
const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
typename TTypes<T>::Matrix h) {
// Concat xh = [x, h].
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
// states1 = xh * w + b
typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, false, T(1), const_xh, w, T(0), icfo);
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(ctx, d, false, false, T(1),
const_xh, w, T(0), icfo);
Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
@ -239,8 +230,8 @@ struct LSTMBlockCellBprop : public LSTMBlockCell {
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, bool use_peephole,
typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
@ -305,8 +296,8 @@ struct BlockLSTMBprop : public LSTMBlockCell {
: LSTMBlockCell(batch_size, input_size, cell_size) {}
void operator()(
OpKernelContext* ctx, perftools::gputools::Stream* stream,
const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
OpKernelContext* ctx, const Device& d, bool use_peephole,
typename TTypes<T>::ConstMatrix x,
typename TTypes<T>::ConstMatrix cs_prev,
typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
@ -364,7 +355,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
dicfo.dimensions());
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
ctx, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
// xh.
xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
@ -377,7 +368,7 @@ struct BlockLSTMBprop : public LSTMBlockCell {
// w_grad.
TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
ctx, stream, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
ctx, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
// b_grad.
b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));