Update LSTM/GRU to support masking inputs with CuDNN kernel.
Since CuDNN kernel only support right padded data, the GPU specific function has been updated with a tf cond to check that. If the batch of the data meet that criteria, then it could use the CuDNN kernel, otherwise it will fallback to use the normal kernel on GPU. PiperOrigin-RevId: 259607726
This commit is contained in:
parent
4de4b9b511
commit
c33f1d1a61
@ -626,9 +626,63 @@ class GRUGraphRewriteTest(keras_parameterized.TestCase):
|
||||
model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
|
||||
self._test_runtime_with_model(model)
|
||||
|
||||
def test_GRU_runtime_with_mask(self):
|
||||
# Masking will affect which backend is selected based on whether the mask
|
||||
# is strictly right padded.
|
||||
layer = rnn.GRU(self.rnn_state_size, return_runtime=True)
|
||||
|
||||
inputs = keras.layers.Input(
|
||||
shape=[self.timestep, self.input_shape], dtype=dtypes.float32)
|
||||
masked_inputs = keras.layers.Masking()(inputs)
|
||||
|
||||
outputs, runtime = layer(masked_inputs)
|
||||
# Expand the runtime so that it is a 1D tensor instead of scalar.
|
||||
# TF model does not work with scalar model output, specially during
|
||||
# aggregation.
|
||||
runtime = keras.layers.Lambda(
|
||||
lambda x: array_ops.expand_dims(x, axis=-1))(runtime)
|
||||
model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
|
||||
|
||||
(x_train, y_train), _ = testing_utils.get_test_data(
|
||||
train_samples=self.batch,
|
||||
test_samples=0,
|
||||
input_shape=(self.timestep, self.input_shape),
|
||||
num_classes=self.output_shape)
|
||||
y_train = keras.utils.to_categorical(y_train, self.output_shape)
|
||||
|
||||
model.compile(optimizer='sgd',
|
||||
loss=['categorical_crossentropy', None],
|
||||
run_eagerly=testing_utils.should_run_eagerly(),
|
||||
run_distributed=testing_utils.should_run_distributed())
|
||||
|
||||
model.fit(x_train, y_train)
|
||||
|
||||
# Verify unpadded data.
|
||||
_, runtime_value = model.predict(x_train)
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_GPU)
|
||||
else:
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Update x/y to be right padded by setting the last timestep to 0
|
||||
x_train[:, -1, :] = 0
|
||||
y_train[:, -1] = 0
|
||||
_, runtime_value = model.predict(x_train)
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_GPU)
|
||||
else:
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Further update x/y to be mix padded (masks in the middle), and verify
|
||||
# only cpu kernel can be selected.
|
||||
x_train[:, -3, :] = 0
|
||||
y_train[:, -3] = 0
|
||||
_, runtime_value = model.predict(x_train)
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Due to b/120160788.
|
||||
@test_util.run_v2_only
|
||||
def test_UnifiedGRU_with_cond(self):
|
||||
def test_GRU_runtime_with_cond(self):
|
||||
# This test is to demonstrate the graph rewrite of grappler plugin under
|
||||
# the condition that the function returns different number of internal
|
||||
# states.
|
||||
|
@ -769,6 +769,7 @@ class LSTMGraphRewriteTest(keras_parameterized.TestCase):
|
||||
|
||||
model.compile(optimizer='sgd',
|
||||
loss=['categorical_crossentropy', None],
|
||||
run_eagerly=testing_utils.should_run_eagerly(),
|
||||
run_distributed=testing_utils.should_run_distributed())
|
||||
|
||||
existing_loss = 0
|
||||
@ -800,6 +801,60 @@ class LSTMGraphRewriteTest(keras_parameterized.TestCase):
|
||||
model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
|
||||
self._test_runtime_with_model(model)
|
||||
|
||||
def test_LSTM_runtime_with_mask(self):
|
||||
# Masking will affect which backend is selected based on whether the mask
|
||||
# is strictly right padded.
|
||||
layer = rnn.LSTM(self.rnn_state_size, return_runtime=True)
|
||||
|
||||
inputs = keras.layers.Input(
|
||||
shape=[self.timestep, self.input_shape], dtype=dtypes.float32)
|
||||
masked_inputs = keras.layers.Masking()(inputs)
|
||||
|
||||
outputs, runtime = layer(masked_inputs)
|
||||
# Expand the runtime so that it is a 1D tensor instead of scalar.
|
||||
# TF model does not work with scalar model output, specially during
|
||||
# aggregation.
|
||||
runtime = keras.layers.Lambda(
|
||||
lambda x: array_ops.expand_dims(x, axis=-1))(runtime)
|
||||
model = keras.models.Model(inputs=inputs, outputs=[outputs, runtime])
|
||||
|
||||
(x_train, y_train), _ = testing_utils.get_test_data(
|
||||
train_samples=self.batch,
|
||||
test_samples=0,
|
||||
input_shape=(self.timestep, self.input_shape),
|
||||
num_classes=self.output_shape)
|
||||
y_train = keras.utils.to_categorical(y_train, self.output_shape)
|
||||
|
||||
model.compile(optimizer='sgd',
|
||||
loss=['categorical_crossentropy', None],
|
||||
run_eagerly=testing_utils.should_run_eagerly(),
|
||||
run_distributed=testing_utils.should_run_distributed())
|
||||
|
||||
model.fit(x_train, y_train)
|
||||
|
||||
# Verify unpadded data.
|
||||
_, runtime_value = model.predict(x_train)
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_GPU)
|
||||
else:
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Update x/y to be right padded by setting the last timestep to 0
|
||||
x_train[:, -1, :] = 0
|
||||
y_train[:, -1] = 0
|
||||
_, runtime_value = model.predict(x_train)
|
||||
if test.is_gpu_available():
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_GPU)
|
||||
else:
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Further update x/y to be mix padded (masks in the middle), and verify
|
||||
# only cpu kernel can be selected.
|
||||
x_train[:, -3, :] = 0
|
||||
y_train[:, -3] = 0
|
||||
_, runtime_value = model.predict(x_train)
|
||||
self.assertEqual(runtime_value[0], rnn._RUNTIME_CPU)
|
||||
|
||||
# Due to b/120160788.
|
||||
@test_util.run_v2_only
|
||||
def test_LSTM_runtime_with_cond(self):
|
||||
|
@ -399,22 +399,8 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
|
||||
else:
|
||||
last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs)
|
||||
else:
|
||||
if mask is None:
|
||||
last_output, outputs, new_h, runtime = gru_with_backend_selection(
|
||||
normal_gru_kwargs, cudnn_gru_kwargs)
|
||||
else:
|
||||
def with_mask_support():
|
||||
# TODO(b/134702514): Change to use backend selection.
|
||||
# return gru_with_backend_selection(normal_gru_kwargs,
|
||||
# cudnn_gru_kwargs)
|
||||
return standard_gru(**normal_gru_kwargs)
|
||||
def without_mask_support():
|
||||
return standard_gru(**normal_gru_kwargs)
|
||||
|
||||
last_output, outputs, new_h, runtime = control_flow_ops.cond(
|
||||
is_sequence_right_padded(mask, self.time_major),
|
||||
true_fn=with_mask_support,
|
||||
false_fn=without_mask_support)
|
||||
last_output, outputs, new_h, runtime = gru_with_backend_selection(
|
||||
**normal_gru_kwargs)
|
||||
|
||||
states = [new_h]
|
||||
return last_output, outputs, runtime, states
|
||||
@ -568,7 +554,9 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
|
||||
return last_output, outputs, h, _runtime(_RUNTIME_GPU)
|
||||
|
||||
|
||||
def gru_with_backend_selection(normal_gru_params, cudnn_gru_params):
|
||||
def gru_with_backend_selection(
|
||||
inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
|
||||
go_backwards, activation, recurrent_activation):
|
||||
"""Call the GRU with optimized backend kernel selection.
|
||||
|
||||
Under the hood, this function will create two TF function, one with the most
|
||||
@ -581,12 +569,69 @@ def gru_with_backend_selection(normal_gru_params, cudnn_gru_params):
|
||||
device placement.
|
||||
|
||||
Args:
|
||||
normal_gru_params: Dict, parameters for the generic TF function.
|
||||
cudnn_gru_params: Dict, parameters for the CuDNN specific TF function.
|
||||
inputs: Input tensor of GRU layer.
|
||||
init_h: Initial state tensor for the cell output.
|
||||
kernel: Weights for cell kernel.
|
||||
recurrent_kernel: Weights for cell recurrent kernel.
|
||||
bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
|
||||
is used in this case.
|
||||
mask: Boolean tensor for mask out the steps within sequence.
|
||||
time_major: Boolean, whether the inputs are in the format of
|
||||
[time, batch, feature] or [batch, time, feature].
|
||||
go_backwards: Boolean (default False). If True, process the input sequence
|
||||
backwards and return the reversed sequence.
|
||||
activation: Activation function to use for output.
|
||||
recurrent_activation: Activation function to use for hidden recurrent state.
|
||||
|
||||
Returns:
|
||||
List of output tensors, same as standard_gru.
|
||||
"""
|
||||
params = {
|
||||
'inputs': inputs,
|
||||
'init_h': init_h,
|
||||
'kernel': kernel,
|
||||
'recurrent_kernel': recurrent_kernel,
|
||||
'bias': bias,
|
||||
'mask': mask,
|
||||
'time_major': time_major,
|
||||
'go_backwards': go_backwards,
|
||||
'activation': activation,
|
||||
'recurrent_activation': recurrent_activation
|
||||
}
|
||||
|
||||
def cudnn_gru_with_fallback(inputs, init_h, kernel, recurrent_kernel,
|
||||
bias, mask, time_major, go_backwards, activation,
|
||||
recurrent_activation):
|
||||
"""Use CuDNN kernel when mask is none or strictly right padded."""
|
||||
if mask is None:
|
||||
return cudnn_gru(inputs=inputs, init_h=init_h, kernel=kernel,
|
||||
recurrent_kernel=recurrent_kernel, bias=bias, mask=mask,
|
||||
time_major=time_major, go_backwards=go_backwards)
|
||||
# Note that mask is a boolean tensor, which doesn't need to do gradient
|
||||
# calculation, when using tf.cond, a default gradient is added for it,
|
||||
# which then cause the backward function to have a signature mismatch.
|
||||
# Force the mask to not generate gradient to allow implementation_selector
|
||||
# to work properly.
|
||||
# TODO(b/80444525): Remove the stop_gradient().
|
||||
mask = array_ops.stop_gradient(mask)
|
||||
|
||||
def input_right_padded():
|
||||
return cudnn_gru(inputs=inputs, init_h=init_h, kernel=kernel,
|
||||
recurrent_kernel=recurrent_kernel, bias=bias, mask=mask,
|
||||
time_major=time_major, go_backwards=go_backwards)
|
||||
|
||||
def input_not_right_padded():
|
||||
return standard_gru(inputs=inputs, init_h=init_h, kernel=kernel,
|
||||
recurrent_kernel=recurrent_kernel, bias=bias,
|
||||
mask=mask, time_major=time_major,
|
||||
go_backwards=go_backwards, activation=activation,
|
||||
recurrent_activation=recurrent_activation)
|
||||
|
||||
return control_flow_ops.cond(
|
||||
is_sequence_right_padded(mask, time_major),
|
||||
true_fn=input_right_padded,
|
||||
false_fn=input_not_right_padded)
|
||||
|
||||
# Each time a `tf.function` is called, we will give it a unique
|
||||
# identifiable API name, so that Grappler won't get confused when it
|
||||
# sees multiple GRU layers added into same graph, and it will be able
|
||||
@ -595,14 +640,12 @@ def gru_with_backend_selection(normal_gru_params, cudnn_gru_params):
|
||||
defun_standard_gru = _generate_defun_backend(
|
||||
api_name, _CPU_DEVICE_NAME, standard_gru)
|
||||
defun_cudnn_gru = _generate_defun_backend(
|
||||
api_name, _GPU_DEVICE_NAME, cudnn_gru)
|
||||
api_name, _GPU_DEVICE_NAME, cudnn_gru_with_fallback)
|
||||
|
||||
# Call the normal GRU impl and register the CuDNN impl function. The
|
||||
# grappler will kick in during session execution to optimize the graph.
|
||||
last_output, outputs, new_h, runtime = defun_standard_gru(
|
||||
**normal_gru_params)
|
||||
|
||||
function.register(defun_cudnn_gru, **cudnn_gru_params)
|
||||
last_output, outputs, new_h, runtime = defun_standard_gru(**params)
|
||||
function.register(defun_cudnn_gru, **params)
|
||||
return last_output, outputs, new_h, runtime
|
||||
|
||||
|
||||
@ -919,24 +962,8 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
|
||||
last_output, outputs, new_h, new_c, runtime = standard_lstm(
|
||||
**normal_lstm_kwargs)
|
||||
else:
|
||||
if mask is None:
|
||||
(last_output, outputs,
|
||||
new_h, new_c, runtime) = lstm_with_backend_selection(
|
||||
normal_lstm_kwargs, cudnn_lstm_kwargs)
|
||||
else:
|
||||
def with_mask_support():
|
||||
# TODO(b/134702514): Change to use backend selection.
|
||||
# return lstm_with_backend_selection(normal_lstm_kwargs,
|
||||
# cudnn_lstm_kwargs)
|
||||
return standard_lstm(**normal_lstm_kwargs)
|
||||
def without_mask_support():
|
||||
return standard_lstm(**normal_lstm_kwargs)
|
||||
|
||||
(last_output, outputs,
|
||||
new_h, new_c, runtime) = control_flow_ops.cond(
|
||||
is_sequence_right_padded(mask, self.time_major),
|
||||
true_fn=with_mask_support,
|
||||
false_fn=without_mask_support)
|
||||
(last_output, outputs, new_h, new_c,
|
||||
runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
|
||||
|
||||
states = [new_h, new_c]
|
||||
|
||||
@ -1162,7 +1189,9 @@ def cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
||||
return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)
|
||||
|
||||
|
||||
def lstm_with_backend_selection(normal_lstm_params, cudnn_lstm_params):
|
||||
def lstm_with_backend_selection(
|
||||
inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major,
|
||||
go_backwards, activation, recurrent_activation):
|
||||
"""Call the LSTM with optimized backend kernel selection.
|
||||
|
||||
Under the hood, this function will create two TF function, one with the most
|
||||
@ -1175,12 +1204,73 @@ def lstm_with_backend_selection(normal_lstm_params, cudnn_lstm_params):
|
||||
device placement.
|
||||
|
||||
Args:
|
||||
normal_lstm_params: Dict, parameters for the generic TF function.
|
||||
cudnn_lstm_params: Dict, parameters for the CuDNN specific TF function.
|
||||
inputs: Input tensor of LSTM layer.
|
||||
init_h: Initial state tensor for the cell output.
|
||||
init_c: Initial state tensor for the cell hidden state.
|
||||
kernel: Weights for cell kernel.
|
||||
recurrent_kernel: Weights for cell recurrent kernel.
|
||||
bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
|
||||
is used in this case.
|
||||
mask: Boolean tensor for mask out the steps within sequence.
|
||||
time_major: Boolean, whether the inputs are in the format of
|
||||
[time, batch, feature] or [batch, time, feature].
|
||||
go_backwards: Boolean (default False). If True, process the input sequence
|
||||
backwards and return the reversed sequence.
|
||||
activation: Activation function to use for output.
|
||||
recurrent_activation: Activation function to use for hidden recurrent state.
|
||||
|
||||
Returns:
|
||||
List of output tensors, same as standard_lstm.
|
||||
"""
|
||||
params = {
|
||||
'inputs': inputs,
|
||||
'init_h': init_h,
|
||||
'init_c': init_c,
|
||||
'kernel': kernel,
|
||||
'recurrent_kernel': recurrent_kernel,
|
||||
'bias': bias,
|
||||
'mask': mask,
|
||||
'time_major': time_major,
|
||||
'go_backwards': go_backwards,
|
||||
'activation': activation,
|
||||
'recurrent_activation': recurrent_activation
|
||||
}
|
||||
|
||||
def cudnn_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
|
||||
bias, mask, time_major, go_backwards, activation,
|
||||
recurrent_activation):
|
||||
"""Use CuDNN kernel when mask is none or strictly right padded."""
|
||||
if mask is None:
|
||||
return cudnn_lstm(inputs=inputs, init_h=init_h, init_c=init_c,
|
||||
kernel=kernel, recurrent_kernel=recurrent_kernel,
|
||||
bias=bias, mask=mask, time_major=time_major,
|
||||
go_backwards=go_backwards)
|
||||
# Note that mask is a boolean tensor, which doesn't need to do gradient
|
||||
# calculation, when using tf.cond, a default gradient is added for it,
|
||||
# which then cause the backward function to have a signature mismatch.
|
||||
# Force the mask to not generate gradient to allow implementation_selector
|
||||
# to work properly.
|
||||
# TODO(b/80444525): Remove the stop_gradient().
|
||||
mask = array_ops.stop_gradient(mask)
|
||||
|
||||
def input_right_padded():
|
||||
return cudnn_lstm(inputs=inputs, init_h=init_h, init_c=init_c,
|
||||
kernel=kernel, recurrent_kernel=recurrent_kernel,
|
||||
bias=bias, mask=mask, time_major=time_major,
|
||||
go_backwards=go_backwards)
|
||||
|
||||
def input_not_right_padded():
|
||||
return standard_lstm(inputs=inputs, init_h=init_h, init_c=init_c,
|
||||
kernel=kernel, recurrent_kernel=recurrent_kernel,
|
||||
bias=bias, mask=mask, time_major=time_major,
|
||||
go_backwards=go_backwards, activation=activation,
|
||||
recurrent_activation=recurrent_activation)
|
||||
|
||||
return control_flow_ops.cond(
|
||||
is_sequence_right_padded(mask, time_major),
|
||||
true_fn=input_right_padded,
|
||||
false_fn=input_not_right_padded)
|
||||
|
||||
# Each time a `tf.function` is called, we will give it a unique
|
||||
# identifiable API name, so that Grappler won't get confused when it
|
||||
# sees multiple LSTM layers added into same graph, and it will be able
|
||||
@ -1189,14 +1279,14 @@ def lstm_with_backend_selection(normal_lstm_params, cudnn_lstm_params):
|
||||
defun_standard_lstm = _generate_defun_backend(
|
||||
api_name, _CPU_DEVICE_NAME, standard_lstm)
|
||||
defun_cudnn_lstm = _generate_defun_backend(
|
||||
api_name, _GPU_DEVICE_NAME, cudnn_lstm)
|
||||
api_name, _GPU_DEVICE_NAME, cudnn_lstm_with_fallback)
|
||||
|
||||
# Call the normal LSTM impl and register the CuDNN impl function. The
|
||||
# grappler will kick in during session execution to optimize the graph.
|
||||
last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
|
||||
**normal_lstm_params)
|
||||
**params)
|
||||
function.register(defun_cudnn_lstm, **params)
|
||||
|
||||
function.register(defun_cudnn_lstm, **cudnn_lstm_params)
|
||||
return last_output, outputs, new_h, new_c, runtime
|
||||
|
||||
|
||||
@ -1264,7 +1354,8 @@ def _generate_defun_backend(unique_api_name, preferred_device, func):
|
||||
_DEFUN_DEVICE_ATTRIBUTE: preferred_device,
|
||||
}
|
||||
return function.defun_with_attributes(func=func,
|
||||
attributes=function_attributes)
|
||||
attributes=function_attributes,
|
||||
autograph=False)
|
||||
|
||||
|
||||
def _get_context_device_type():
|
||||
|
Loading…
Reference in New Issue
Block a user