Replace the usages of cudnn_rnn, cudnn_rnnv2, and cudnn_rnnv3 with their TF public API counterparts (CudnnRNN, CudnnRNNV2, and CudnnRNNV3).

PiperOrigin-RevId: 341494517
Change-Id: I9ffcbfb5bc5f0be58ab5a557e67ecf93a8fe7e6b
Authored by Yanhui Liang on 2020-11-09 15:11:04 -08:00; committed by TensorFlower Gardener
parent e396a0ce01
commit 51dc6eca9c
2 changed files with 9 additions and 9 deletions

@@ -504,7 +504,7 @@ class CuDNNLSTM(_CuDNNRNN):
         'is_training': True,
     }
 
-    outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
+    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV2(**args)
 
     if self.stateful or self.return_state:
       h = h[0]
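For context: the CamelCase names in the generated op module are the endpoints TensorFlow also exposes publicly under tf.raw_ops, and they accept keyword arguments only, which is why the call sites in this commit switch to input=inputs. A minimal sketch of the equivalent public call, assuming a CUDA-enabled GPU build; the shapes and the randomly filled params buffer below are illustrative only, not part of this commit:

import tensorflow as tf

# Hypothetical sizes, for illustration only.
time_steps, batch, input_size, num_units, num_layers = 4, 2, 3, 5, 1

# Ask cuDNN how large the opaque parameter buffer must be, then fill it
# with random values (a real model would pack trained weights instead).
params_size = tf.raw_ops.CudnnRNNParamsSize(
    num_layers=num_layers, num_units=num_units, input_size=input_size,
    T=tf.float32, S=tf.int32, rnn_mode='lstm')
params = tf.random.uniform([params_size], dtype=tf.float32)

# Public raw-op endpoint; all arguments must be passed by keyword.
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(
    input=tf.random.uniform([time_steps, batch, input_size]),  # time major
    input_h=tf.zeros([num_layers, batch, num_units]),
    input_c=tf.zeros([num_layers, batch, num_units]),
    params=params,
    rnn_mode='lstm',
    is_training=True)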

@@ -698,9 +698,9 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
     if go_backwards:
       # Reverse axis 0 since the input is already converted to time major.
       inputs = array_ops.reverse(inputs, axis=[0])
-    outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
-        inputs, input_h=init_h, input_c=0, params=params, is_training=True,
-        rnn_mode='gru')
+    outputs, h, _, _ = gen_cudnn_rnn_ops.CudnnRNN(
+        input=inputs, input_h=init_h, input_c=0, params=params,
+        is_training=True, rnn_mode='gru')
 
   last_output = outputs[-1]
   if not time_major and mask is None:
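As an aside on the go_backwards branch above: since the inputs are already time major at this point, reversing axis 0 flips the time dimension. A tiny CPU-runnable illustration with tf.reverse, the public counterpart of array_ops.reverse (the shapes are made up for the example):

import tensorflow as tf

# Time-major input: [time=3, batch=2, features=1].
x = tf.reshape(tf.range(6, dtype=tf.float32), [3, 2, 1])
x_rev = tf.reverse(x, axis=[0])

# The time axis is reversed; batch and feature axes are untouched.
# x[:, 0, 0]     -> [0., 2., 4.]
# x_rev[:, 0, 0] -> [4., 2., 0.]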
@@ -1486,8 +1486,8 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
       # expected_output = [0, 0, 6, 5, 4]
       inputs = array_ops.reverse_sequence_v2(
           inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
-    outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
-        inputs,
+    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
+        input=inputs,
         input_h=init_h,
         input_c=init_c,
         params=params,
@@ -1506,9 +1506,9 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
     if go_backwards:
       # Reverse axis 0 since the input is already converted to time major.
       inputs = array_ops.reverse(inputs, axis=[0])
-    outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
-        inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
-        rnn_mode='lstm')
+    outputs, h, c, _ = gen_cudnn_rnn_ops.CudnnRNN(
+        input=inputs, input_h=init_h, input_c=init_c, params=params,
+        is_training=True, rnn_mode='lstm')
 
   last_output = outputs[-1]
   if not time_major and mask is None:
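This refactor should not change user-facing behavior: gpu_gru and gpu_lstm are the fused cuDNN code paths behind the standard Keras layers, so typical user code keeps calling the public layer API. A hedged usage sketch (the cuDNN kernels are only used on a CUDA GPU and only while the layers keep their cuDNN-compatible defaults, e.g. activation='tanh', recurrent_activation='sigmoid', recurrent_dropout=0):

import tensorflow as tf

# Dummy batch: [batch=8, time=10, features=16].
x = tf.random.uniform([8, 10, 16])

# With default arguments these layers dispatch to the fused cuDNN kernels
# (the gpu_lstm / gpu_gru paths patched above) when a GPU is available,
# and fall back to the generic implementation otherwise.
lstm = tf.keras.layers.LSTM(32, return_sequences=True, return_state=True)
outputs, h, c = lstm(x)

gru = tf.keras.layers.GRU(32)
last_output = gru(x)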