Do not pass input_length to K.rnn in RNN layer since it is redundant.
The input_length arg is passed as the maximum_iterations arg to tf.while_loop, which adds a LogicalAnd to the loop condition; that extra op is slow on GPU. input_length is therefore only passed when building for XLA. PiperOrigin-RevId: 261822039
This commit is contained in:
parent
476fd9f8da
commit
640b5f2513
@ -492,6 +492,7 @@ py_library(
|
|||||||
":generic_utils",
|
":generic_utils",
|
||||||
":tf_utils",
|
":tf_utils",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:control_flow_util",
|
||||||
"//tensorflow/python:cudnn_rnn_ops_gen",
|
"//tensorflow/python:cudnn_rnn_ops_gen",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:embedding_ops",
|
"//tensorflow/python:embedding_ops",
|
||||||
|
@ -24,6 +24,7 @@ import collections
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.keras import activations
|
from tensorflow.python.keras import activations
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
@ -35,6 +36,7 @@ from tensorflow.python.keras.engine.input_spec import InputSpec
|
|||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
@ -732,6 +734,15 @@ class RNN(Layer):
|
|||||||
new_states = [new_states]
|
new_states = [new_states]
|
||||||
return output, new_states
|
return output, new_states
|
||||||
|
|
||||||
|
# `input_length` is passed as the `maximum_iterations` arg to tf.while_loop.
|
||||||
|
# We only specify it when building for XLA, since setting it causes slowdowns
|
||||||
|
# on GPU in TF.
|
||||||
|
if (not context.executing_eagerly() and
|
||||||
|
control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())):
|
||||||
|
input_length = timesteps
|
||||||
|
else:
|
||||||
|
input_length = None
|
||||||
|
|
||||||
last_output, outputs, states = K.rnn(
|
last_output, outputs, states = K.rnn(
|
||||||
step,
|
step,
|
||||||
inputs,
|
inputs,
|
||||||
@ -740,7 +751,7 @@ class RNN(Layer):
|
|||||||
go_backwards=self.go_backwards,
|
go_backwards=self.go_backwards,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
unroll=self.unroll,
|
unroll=self.unroll,
|
||||||
input_length=timesteps,
|
input_length=input_length,
|
||||||
time_major=self.time_major,
|
time_major=self.time_major,
|
||||||
zero_output_for_mask=self.zero_output_for_mask)
|
zero_output_for_mask=self.zero_output_for_mask)
|
||||||
if self.stateful:
|
if self.stateful:
|
||||||
|
Loading…
Reference in New Issue
Block a user