Do not pass input_length to K.rnn in RNN layer since it is redundant.
The input_length arg is passed as the maximum_iterations arg to tf.while_loop which adds a LogicalAnd to the loop condition which is slow on GPU. PiperOrigin-RevId: 261822039
This commit is contained in:
parent
476fd9f8da
commit
640b5f2513
@ -492,6 +492,7 @@ py_library(
|
||||
":generic_utils",
|
||||
":tf_utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:cudnn_rnn_ops_gen",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:embedding_ops",
|
||||
|
@ -24,6 +24,7 @@ import collections
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import activations
|
||||
from tensorflow.python.keras import backend as K
|
||||
@ -35,6 +36,7 @@ from tensorflow.python.keras.engine.input_spec import InputSpec
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
@ -732,6 +734,15 @@ class RNN(Layer):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
|
||||
# `input_length` is passed as the `maximum_iterations` arg to tf.while_loop.
|
||||
# We only specify that when building for XLA since that causes slowdowns
|
||||
# on GPU in TF.
|
||||
if (not context.executing_eagerly() and
|
||||
control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
|
||||
input_length = timesteps
|
||||
else:
|
||||
input_length = None
|
||||
|
||||
last_output, outputs, states = K.rnn(
|
||||
step,
|
||||
inputs,
|
||||
@ -740,7 +751,7 @@ class RNN(Layer):
|
||||
go_backwards=self.go_backwards,
|
||||
mask=mask,
|
||||
unroll=self.unroll,
|
||||
input_length=timesteps,
|
||||
input_length=input_length,
|
||||
time_major=self.time_major,
|
||||
zero_output_for_mask=self.zero_output_for_mask)
|
||||
if self.stateful:
|
||||
|
Loading…
Reference in New Issue
Block a user