Replace the usages of cudnn_rnn
, cudnn_rnnv2
and cudnn_rnnv3
with TF public apis.
PiperOrigin-RevId: 341494517 Change-Id: I9ffcbfb5bc5f0be58ab5a557e67ecf93a8fe7e6b
This commit is contained in:
parent
e396a0ce01
commit
51dc6eca9c
@ -504,7 +504,7 @@ class CuDNNLSTM(_CuDNNRNN):
|
|||||||
'is_training': True,
|
'is_training': True,
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
|
outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)
|
||||||
|
|
||||||
if self.stateful or self.return_state:
|
if self.stateful or self.return_state:
|
||||||
h = h[0]
|
h = h[0]
|
||||||
|
@ -698,9 +698,9 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
|
|||||||
if go_backwards:
|
if go_backwards:
|
||||||
# Reverse axis 0 since the input is already convert to time major.
|
# Reverse axis 0 since the input is already convert to time major.
|
||||||
inputs = array_ops.reverse(inputs, axis=[0])
|
inputs = array_ops.reverse(inputs, axis=[0])
|
||||||
outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
outputs, h, _, _ = gen_cudnn_rnn_ops.CudnnRNN(
|
||||||
inputs, input_h=init_h, input_c=0, params=params, is_training=True,
|
input=inputs, input_h=init_h, input_c=0, params=params,
|
||||||
rnn_mode='gru')
|
is_training=True, rnn_mode='gru')
|
||||||
|
|
||||||
last_output = outputs[-1]
|
last_output = outputs[-1]
|
||||||
if not time_major and mask is None:
|
if not time_major and mask is None:
|
||||||
@ -1486,8 +1486,8 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
|||||||
# expected_output = [0, 0, 6, 5 ,4]
|
# expected_output = [0, 0, 6, 5 ,4]
|
||||||
inputs = array_ops.reverse_sequence_v2(
|
inputs = array_ops.reverse_sequence_v2(
|
||||||
inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
|
inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
|
||||||
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
|
outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
|
||||||
inputs,
|
input=inputs,
|
||||||
input_h=init_h,
|
input_h=init_h,
|
||||||
input_c=init_c,
|
input_c=init_c,
|
||||||
params=params,
|
params=params,
|
||||||
@ -1506,9 +1506,9 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
|||||||
if go_backwards:
|
if go_backwards:
|
||||||
# Reverse axis 0 since the input is already convert to time major.
|
# Reverse axis 0 since the input is already convert to time major.
|
||||||
inputs = array_ops.reverse(inputs, axis=[0])
|
inputs = array_ops.reverse(inputs, axis=[0])
|
||||||
outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
|
outputs, h, c, _ = gen_cudnn_rnn_ops.CudnnRNN(
|
||||||
inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
|
input=inputs, input_h=init_h, input_c=init_c, params=params,
|
||||||
rnn_mode='lstm')
|
is_training=True, rnn_mode='lstm')
|
||||||
|
|
||||||
last_output = outputs[-1]
|
last_output = outputs[-1]
|
||||||
if not time_major and mask is None:
|
if not time_major and mask is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user