Fix: fixed LSTMBlockCell cuda kernel

This commit is contained in:
vladbataev 2020-02-10 21:02:18 +03:00
parent e45c9f2272
commit 50c58837d2
3 changed files with 57 additions and 3 deletions

View File

@ -350,8 +350,8 @@ __global__ void lstm_gates_bprop(
di[cid] = di_local;
dgates[gid + 0 * cell_size] = di_local;
dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
dgates[gid + gate_c_offset(gate_layout, cell_size)] = dci_local;
dgates[gid + gate_f_offset(gate_layout, cell_size)] = df_local;
dgates[gid + 3 * cell_size] = do_local;
cs_prev_grad[cid] = dcs_local * f_local;

View File

@ -4371,7 +4371,7 @@ py_library(
],
)
py_test(
cuda_py_test(
name = "rnn_grad_test",
srcs = ["ops/rnn_grad_test.py"],
python_version = "PY3",

View File

@ -66,6 +66,60 @@ class RNNGradTest(test.TestCase):
self.assertAllEqual(w_grad, w_ifco_grad)
self.assertAllEqual(b_grad, b_ifco_grad)
@test_util.deprecated_graph_mode_only
def testLSTMBlockCell(self):
batch_size = np.random.randint(1, 32)
input_size = np.random.randint(1, 32)
hidden_size = np.random.randint(1, 32)
w = deterministic_random_uniform(
[input_size + hidden_size, 4 * hidden_size])
b = deterministic_random_uniform([4 * hidden_size])
x = deterministic_random_uniform([batch_size, input_size])
cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size])
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
cs_grad = deterministic_random_uniform([batch_size, hidden_size])
h_grad = deterministic_random_uniform([batch_size, hidden_size])
outputs = []
grads = []
for use_gpu in [False, True]:
with self.cached_session(use_gpu=use_gpu):
output = gen_rnn_ops.lstm_block_cell(
x=x,
cs_prev=cs_prev,
h_prev=h_prev,
w=w,
wci=w_peephole,
wcf=w_peephole,
wco=w_peephole,
b=b,
forget_bias=1.0,
cell_clip=0.0,
use_peephole=False)
(i, cs, f, o, ci, co, _) = output
grad = gen_rnn_ops.lstm_block_cell_grad(
x=x,
cs_prev=cs_prev,
h_prev=h_prev,
w=w,
wci=w_peephole,
wcf=w_peephole,
wco=w_peephole,
b=b,
i=i,
cs=cs,
f=f,
o=o,
ci=ci,
co=co,
cs_grad=cs_grad,
h_grad=h_grad,
use_peephole=False)
outputs.append(output)
grads.append(grad)
self.assertAllClose(outputs[0], outputs[1])
self.assertAllClose(grads[0], grads[1])
def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
_, all_cs, _, _, _, _, all_h = op(