Fix: fixed LSTMBlockCell cuda kernel
This commit is contained in:
parent
e45c9f2272
commit
50c58837d2
@ -350,8 +350,8 @@ __global__ void lstm_gates_bprop(
|
|||||||
di[cid] = di_local;
|
di[cid] = di_local;
|
||||||
|
|
||||||
dgates[gid + 0 * cell_size] = di_local;
|
dgates[gid + 0 * cell_size] = di_local;
|
||||||
dgates[gate_c_offset(gate_layout, cell_size)] = dci_local;
|
dgates[gid + gate_c_offset(gate_layout, cell_size)] = dci_local;
|
||||||
dgates[gate_f_offset(gate_layout, cell_size)] = df_local;
|
dgates[gid + gate_f_offset(gate_layout, cell_size)] = df_local;
|
||||||
dgates[gid + 3 * cell_size] = do_local;
|
dgates[gid + 3 * cell_size] = do_local;
|
||||||
|
|
||||||
cs_prev_grad[cid] = dcs_local * f_local;
|
cs_prev_grad[cid] = dcs_local * f_local;
|
||||||
|
@ -4371,7 +4371,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
cuda_py_test(
|
||||||
name = "rnn_grad_test",
|
name = "rnn_grad_test",
|
||||||
srcs = ["ops/rnn_grad_test.py"],
|
srcs = ["ops/rnn_grad_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
@ -66,6 +66,60 @@ class RNNGradTest(test.TestCase):
|
|||||||
self.assertAllEqual(w_grad, w_ifco_grad)
|
self.assertAllEqual(w_grad, w_ifco_grad)
|
||||||
self.assertAllEqual(b_grad, b_ifco_grad)
|
self.assertAllEqual(b_grad, b_ifco_grad)
|
||||||
|
|
||||||
|
@test_util.deprecated_graph_mode_only
|
||||||
|
def testLSTMBlockCell(self):
|
||||||
|
batch_size = np.random.randint(1, 32)
|
||||||
|
input_size = np.random.randint(1, 32)
|
||||||
|
hidden_size = np.random.randint(1, 32)
|
||||||
|
w = deterministic_random_uniform(
|
||||||
|
[input_size + hidden_size, 4 * hidden_size])
|
||||||
|
b = deterministic_random_uniform([4 * hidden_size])
|
||||||
|
x = deterministic_random_uniform([batch_size, input_size])
|
||||||
|
cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size])
|
||||||
|
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
||||||
|
cs_grad = deterministic_random_uniform([batch_size, hidden_size])
|
||||||
|
h_grad = deterministic_random_uniform([batch_size, hidden_size])
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
grads = []
|
||||||
|
for use_gpu in [False, True]:
|
||||||
|
with self.cached_session(use_gpu=use_gpu):
|
||||||
|
output = gen_rnn_ops.lstm_block_cell(
|
||||||
|
x=x,
|
||||||
|
cs_prev=cs_prev,
|
||||||
|
h_prev=h_prev,
|
||||||
|
w=w,
|
||||||
|
wci=w_peephole,
|
||||||
|
wcf=w_peephole,
|
||||||
|
wco=w_peephole,
|
||||||
|
b=b,
|
||||||
|
forget_bias=1.0,
|
||||||
|
cell_clip=0.0,
|
||||||
|
use_peephole=False)
|
||||||
|
(i, cs, f, o, ci, co, _) = output
|
||||||
|
grad = gen_rnn_ops.lstm_block_cell_grad(
|
||||||
|
x=x,
|
||||||
|
cs_prev=cs_prev,
|
||||||
|
h_prev=h_prev,
|
||||||
|
w=w,
|
||||||
|
wci=w_peephole,
|
||||||
|
wcf=w_peephole,
|
||||||
|
wco=w_peephole,
|
||||||
|
b=b,
|
||||||
|
i=i,
|
||||||
|
cs=cs,
|
||||||
|
f=f,
|
||||||
|
o=o,
|
||||||
|
ci=ci,
|
||||||
|
co=co,
|
||||||
|
cs_grad=cs_grad,
|
||||||
|
h_grad=h_grad,
|
||||||
|
use_peephole=False)
|
||||||
|
outputs.append(output)
|
||||||
|
grads.append(grad)
|
||||||
|
self.assertAllClose(outputs[0], outputs[1])
|
||||||
|
self.assertAllClose(grads[0], grads[1])
|
||||||
|
|
||||||
def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
|
def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
|
||||||
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
||||||
_, all_cs, _, _, _, _, all_h = op(
|
_, all_cs, _, _, _, _, all_h = op(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user