Do not pass input_length to K.rnn in RNN layer since it is redundant.

The input_length arg is passed as the maximum_iterations arg to tf.while_loop, which adds a LogicalAnd to the loop condition; that extra op is slow on GPU.

PiperOrigin-RevId: 261822039
This commit is contained in:
Saurabh Saxena 2019-08-05 19:30:28 -07:00 committed by TensorFlower Gardener
parent 476fd9f8da
commit 640b5f2513
2 changed files with 13 additions and 1 deletions

View File

@ -492,6 +492,7 @@ py_library(
":generic_utils",
":tf_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:cudnn_rnn_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:embedding_ops",

View File

@ -24,6 +24,7 @@ import collections
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
@ -35,6 +36,7 @@ from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
@ -732,6 +734,15 @@ class RNN(Layer):
new_states = [new_states]
return output, new_states
# `input_length` is passed as the `maximum_iterations` arg to tf.while_loop.
# We only specify it when building for XLA, since setting it causes slowdowns
# on GPU in TF.
if (not context.executing_eagerly() and
control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
input_length = timesteps
else:
input_length = None
last_output, outputs, states = K.rnn(
step,
inputs,
@ -740,7 +751,7 @@ class RNN(Layer):
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=timesteps,
input_length=input_length,
time_major=self.time_major,
zero_output_for_mask=self.zero_output_for_mask)
if self.stateful: