Replace the usages of cudnn_rnn
, cudnn_rnnv2
and cudnn_rnnv3
with TF public apis.
PiperOrigin-RevId: 341494517 Change-Id: I9ffcbfb5bc5f0be58ab5a557e67ecf93a8fe7e6b
This commit is contained in:
parent
e396a0ce01
commit
51dc6eca9c
@ -504,7 +504,7 @@ class CuDNNLSTM(_CuDNNRNN):
|
||||
'is_training': True,
|
||||
}
|
||||
|
||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
|
||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)
|
||||
|
||||
if self.stateful or self.return_state:
|
||||
h = h[0]
|
||||
|
@ -698,9 +698,9 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
|
||||
if go_backwards:
|
||||
# Reverse axis 0 since the input is already convert to time major.
|
||||
inputs = array_ops.reverse(inputs, axis=[0])
|
||||
outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
||||
inputs, input_h=init_h, input_c=0, params=params, is_training=True,
|
||||
rnn_mode='gru')
|
||||
outputs, h, _, _ = gen_cudnn_rnn_ops.CudnnRNN(
|
||||
input=inputs, input_h=init_h, input_c=0, params=params,
|
||||
is_training=True, rnn_mode='gru')
|
||||
|
||||
last_output = outputs[-1]
|
||||
if not time_major and mask is None:
|
||||
@ -1486,8 +1486,8 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
||||
# expected_output = [0, 0, 6, 5 ,4]
|
||||
inputs = array_ops.reverse_sequence_v2(
|
||||
inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
|
||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
|
||||
inputs,
|
||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
|
||||
input=inputs,
|
||||
input_h=init_h,
|
||||
input_c=init_c,
|
||||
params=params,
|
||||
@ -1506,9 +1506,9 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
||||
if go_backwards:
|
||||
# Reverse axis 0 since the input is already convert to time major.
|
||||
inputs = array_ops.reverse(inputs, axis=[0])
|
||||
outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
||||
inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
|
||||
rnn_mode='lstm')
|
||||
outputs, h, c, _ = gen_cudnn_rnn_ops.CudnnRNN(
|
||||
input=inputs, input_h=init_h, input_c=init_c, params=params,
|
||||
is_training=True, rnn_mode='lstm')
|
||||
|
||||
last_output = outputs[-1]
|
||||
if not time_major and mask is None:
|
||||
|
Loading…
Reference in New Issue
Block a user