Fix: fixed LSTMBlockCell cuda kernel
This commit is contained in:
parent
e45c9f2272
commit
50c58837d2
@ -350,8 +350,8 @@ __global__ void lstm_gates_bprop(
|
||||
di[cid] = di_local;
|
||||
|
||||
dgates[gid + 0 * cell_size] = di_local;
|
||||
dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
|
||||
dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
|
||||
dgates[gid + gate_c_offset(gate_layout, cell_size)] = dci_local;
|
||||
dgates[gid + gate_f_offset(gate_layout, cell_size)] = df_local;
|
||||
dgates[gid + 3 * cell_size] = do_local;
|
||||
|
||||
cs_prev_grad[cid] = dcs_local * f_local;
|
||||
|
@ -4371,7 +4371,7 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
cuda_py_test(
|
||||
name = "rnn_grad_test",
|
||||
srcs = ["ops/rnn_grad_test.py"],
|
||||
python_version = "PY3",
|
||||
|
@ -66,6 +66,60 @@ class RNNGradTest(test.TestCase):
|
||||
self.assertAllEqual(w_grad, w_ifco_grad)
|
||||
self.assertAllEqual(b_grad, b_ifco_grad)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testLSTMBlockCell(self):
|
||||
batch_size = np.random.randint(1, 32)
|
||||
input_size = np.random.randint(1, 32)
|
||||
hidden_size = np.random.randint(1, 32)
|
||||
w = deterministic_random_uniform(
|
||||
[input_size + hidden_size, 4 * hidden_size])
|
||||
b = deterministic_random_uniform([4 * hidden_size])
|
||||
x = deterministic_random_uniform([batch_size, input_size])
|
||||
cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size])
|
||||
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
||||
cs_grad = deterministic_random_uniform([batch_size, hidden_size])
|
||||
h_grad = deterministic_random_uniform([batch_size, hidden_size])
|
||||
|
||||
outputs = []
|
||||
grads = []
|
||||
for use_gpu in [False, True]:
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
output = gen_rnn_ops.lstm_block_cell(
|
||||
x=x,
|
||||
cs_prev=cs_prev,
|
||||
h_prev=h_prev,
|
||||
w=w,
|
||||
wci=w_peephole,
|
||||
wcf=w_peephole,
|
||||
wco=w_peephole,
|
||||
b=b,
|
||||
forget_bias=1.0,
|
||||
cell_clip=0.0,
|
||||
use_peephole=False)
|
||||
(i, cs, f, o, ci, co, _) = output
|
||||
grad = gen_rnn_ops.lstm_block_cell_grad(
|
||||
x=x,
|
||||
cs_prev=cs_prev,
|
||||
h_prev=h_prev,
|
||||
w=w,
|
||||
wci=w_peephole,
|
||||
wcf=w_peephole,
|
||||
wco=w_peephole,
|
||||
b=b,
|
||||
i=i,
|
||||
cs=cs,
|
||||
f=f,
|
||||
o=o,
|
||||
ci=ci,
|
||||
co=co,
|
||||
cs_grad=cs_grad,
|
||||
h_grad=h_grad,
|
||||
use_peephole=False)
|
||||
outputs.append(output)
|
||||
grads.append(grad)
|
||||
self.assertAllClose(outputs[0], outputs[1])
|
||||
self.assertAllClose(grads[0], grads[1])
|
||||
|
||||
def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
|
||||
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
||||
_, all_cs, _, _, _, _, all_h = op(
|
||||
|
Loading…
x
Reference in New Issue
Block a user