Keras RNN performance fix for RNNEstimator with multi-worker training.

1. Make the default variable caching device the same as the device that consumes the variable. This is the key part: it avoids repeated reads of the variable inside the tf.while_loop, which heavily degrades performance in parameter-server and remote-worker settings.
2. Update the RNN estimator to use the v2 Keras cell implementation, which uses bulk matmuls rather than slicing the weights per gate.
3. Minor change to the GRU cell implementation: replace array index reads with tf.split.

PiperOrigin-RevId: 272926308
Parent: 45793ed5e2
Commit: fe1178ab88
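For context, change (1) hinges on TensorFlow's caching_device mechanism. A minimal, hypothetical sketch of the idea (graph mode, made-up shapes; not code from this commit): when the kernel lives on a parameter server, every iteration of a tf.while_loop would otherwise issue a remote read, and caching_device=lambda op: op.device keeps a locally cached copy on the consuming worker instead.

import tensorflow as tf

# Hypothetical illustration, not part of this commit. In a PS/remote-worker
# setup `kernel` lives on the parameter server; without a caching device each
# loop iteration below re-reads it over the network.
kernel = tf.compat.v1.get_variable(
    'kernel', shape=[4, 4],
    caching_device=lambda op: op.device)  # cache the read on the consuming op's device

x = tf.ones([1, 4])
_, y = tf.while_loop(
    cond=lambda t, h: t < 10,
    body=lambda t, h: (t + 1, tf.matmul(h, kernel)),  # reads `kernel` every step
    loop_vars=(tf.constant(0), x))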
@@ -344,8 +344,8 @@ class Layer(module.Module):
       aggregation: Indicates how a distributed variable will be aggregated.
         Accepted values are constants defined in the class
         `tf.VariableAggregation`.
-      **kwargs: Additional keyword arguments. Accepted values are `getter` and
-        `collections`.
+      **kwargs: Additional keyword arguments. Accepted values are `getter`,
+        `collections`, `experimental_autocast` and `caching_device`.
 
     Returns:
       The created variable. Usually either a `Variable` or `ResourceVariable`
@@ -362,13 +362,16 @@ class Layer(module.Module):
       shape = ()
     # Validate optional keyword arguments.
     for kwarg in kwargs:
-      if kwarg not in ['getter', 'collections', 'experimental_autocast']:
+      if kwarg not in ['getter', 'collections', 'experimental_autocast',
+                       'caching_device']:
         raise TypeError('Unknown keyword argument:', kwarg)
     getter = kwargs.pop('getter', base_layer_utils.make_variable)
     collections_arg = kwargs.pop('collections', None)
     # 'experimental_autocast' can be set to False by the caller to indicate an
     # AutoCastVariable should never be created.
     autocast = kwargs.pop('experimental_autocast', True)
+    # See the docstring for tf.Variable about the details for caching_device.
+    caching_device = kwargs.pop('caching_device', None)
 
     if dtype is None:
       dtype = self.dtype or backend.floatx()
@@ -414,6 +417,13 @@ class Layer(module.Module):
       def getter(*args, **kwargs):  # pylint: disable=function-redefined
         variable = old_getter(*args, **kwargs)
         return autocast_variable.create_autocast_variable(variable)
+      # Also the caching_device does not work with the mixed precision API,
+      # disable it if it is specified.
+      # TODO(b/142020079): Reenable it once the bug is fixed.
+      if caching_device is not None:
+        tf_logging.warn('`caching_device` does not work with mixed precision '
+                        'API. Ignoring user specified `caching_device`.')
+        caching_device = None
 
     variable = self._add_variable_with_custom_getter(
         name=name,
@@ -431,7 +441,8 @@ class Layer(module.Module):
         use_resource=use_resource,
         collections=collections_arg,
         synchronization=synchronization,
-        aggregation=aggregation)
+        aggregation=aggregation,
+        caching_device=caching_device)
     backend.track_variable(variable)
 
     if regularizer is not None:
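To illustrate the add_weight change above, here is a minimal sketch of a user-defined layer forwarding the new caching_device keyword argument; the layer name, shapes, and initializer are made up for the example (and, per the hunk above, the kwarg is ignored with a warning when mixed precision casting is active).

import tensorflow as tf

class TinyDense(tf.keras.layers.Layer):
  """Toy layer showing the caching_device kwarg being passed through add_weight."""

  def build(self, input_shape):
    self.kernel = self.add_weight(
        name='kernel',
        shape=(int(input_shape[-1]), 8),
        initializer='glorot_uniform',
        # Cache variable reads on the device of the op that consumes them.
        caching_device=lambda op: op.device)
    self.built = True

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)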
@@ -24,6 +24,7 @@ import collections
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
@@ -36,6 +37,7 @@ from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -1307,25 +1309,29 @@ class SimpleRNNCell(DropoutRNNCellMixin, Layer):
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
+    default_caching_device = _caching_device(self)
     self.kernel = self.add_weight(
         shape=(input_shape[-1], self.units),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
     if self.use_bias:
       self.bias = self.add_weight(
           shape=(self.units,),
           name='bias',
           initializer=self.bias_initializer,
           regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+          constraint=self.bias_constraint,
+          caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -1737,18 +1743,21 @@ class GRUCell(DropoutRNNCellMixin, Layer):
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
     input_dim = input_shape[-1]
+    default_caching_device = _caching_device(self)
     self.kernel = self.add_weight(
         shape=(input_dim, self.units * 3),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units * 3),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
 
     if self.use_bias:
       if not self.reset_after:
@@ -1763,7 +1772,8 @@ class GRUCell(DropoutRNNCellMixin, Layer):
           name='bias',
           initializer=self.bias_initializer,
           regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+          constraint=self.bias_constraint,
+          caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -1841,9 +1851,7 @@ class GRUCell(DropoutRNNCellMixin, Layer):
         # biases: bias_z_i, bias_r_i, bias_h_i
         matrix_x = K.bias_add(matrix_x, input_bias)
 
-      x_z = matrix_x[:, :self.units]
-      x_r = matrix_x[:, self.units: 2 * self.units]
-      x_h = matrix_x[:, 2 * self.units:]
+      x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
 
       if 0. < self.recurrent_dropout < 1.:
         h_tm1 = h_tm1 * rec_dp_mask[0]
@@ -1857,14 +1865,14 @@ class GRUCell(DropoutRNNCellMixin, Layer):
         # hidden state projected separately for update/reset and new
         matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
 
-      recurrent_z = matrix_inner[:, :self.units]
-      recurrent_r = matrix_inner[:, self.units:2 * self.units]
+      recurrent_z, recurrent_r, recurrent_h = array_ops.split(
+          matrix_inner, [self.units, self.units, -1], axis=-1)
 
       z = self.recurrent_activation(x_z + recurrent_z)
       r = self.recurrent_activation(x_r + recurrent_r)
 
       if self.reset_after:
-        recurrent_h = r * matrix_inner[:, 2 * self.units:]
+        recurrent_h = r * recurrent_h
       else:
         recurrent_h = K.dot(r * h_tm1,
                             self.recurrent_kernel[:, 2 * self.units:])
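The two GRU hunks above swap per-gate slicing for a single split along the last axis, emitting one split op instead of three strided-slice ops. A small standalone sketch of the equivalence (illustrative shapes only, not code from this commit):

import tensorflow as tf

units = 3
matrix_x = tf.reshape(tf.range(2 * 3 * units, dtype=tf.float32), [2, 3 * units])

# New form: one split op yields the z/r/h gate blocks.
x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)

# Old form, slice by slice; produces the same three tensors:
#   matrix_x[:, :units], matrix_x[:, units:2 * units], matrix_x[:, 2 * units:]

# The second hunk uses a size list so the last block can be empty when the
# projection only covers the update/reset gates:
r_z, r_r, r_h = tf.split(matrix_x[:, :2 * units], [units, units, -1], axis=-1)
# Here r_h has shape [2, 0], since nothing remains after the first two blocks.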
@@ -2292,19 +2300,22 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
+    default_caching_device = _caching_device(self)
     input_dim = input_shape[-1]
     self.kernel = self.add_weight(
         shape=(input_dim, self.units * 4),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units * 4),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
 
     if self.use_bias:
       if self.unit_forget_bias:
@@ -2322,7 +2333,8 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
           name='bias',
           initializer=bias_initializer,
           regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+          constraint=self.bias_constraint,
+          caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -2911,3 +2923,48 @@ def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
     return nest.map_structure(create_zeros, state_size)
   else:
     return create_zeros(state_size)
+
+
+def _caching_device(rnn_cell):
+  """Returns the caching device for the RNN variable.
+
+  This is useful for distributed training, when the variable is not located on
+  the same device as the training worker. By enabling the device cache, the
+  worker reads the variable once and caches it locally, rather than reading it
+  from the remote device at every time step.
+
+  Note that this assumes the variable the cell needs at each time step has the
+  same value in the forward path and is only updated during backprop. This is
+  true for all the default cells (SimpleRNN, GRU, LSTM). If the cell body
+  relies on any variable that gets updated every time step, the caching device
+  will cause it to read a stale value.
+
+  Args:
+    rnn_cell: the RNN cell instance.
+  """
+  if context.executing_eagerly():
+    # caching_device is not supported in eager mode.
+    return None
+  # Don't set a caching device when running in a loop, since it is possible
+  # that train steps could be wrapped in a tf.while_loop. In that scenario
+  # caching prevents forward computations in loop iterations from re-reading
+  # the updated weights.
+  if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
+    logging.warn('Variable read device caching has been disabled because the '
+                 'RNN is in a tf.while_loop context, where caching would '
+                 'cause stale values to be read in the forward path. This '
+                 'could slow down training due to duplicated variable reads. '
+                 'Please consider updating your code to remove the '
+                 'tf.while_loop if possible.')
+    return None
+  if rnn_cell._dtype_policy.should_cast_variables:
+    logging.warn('Variable read device caching has been disabled since it '
+                 'does not work with the mixed precision API. This is '
+                 'likely to cause a slowdown for RNN training due to '
+                 'duplicated variable reads at each time step, which '
+                 'can be significant in a multi remote worker setting. '
+                 'Please consider disabling the mixed precision API if '
+                 'performance has been affected.')
+    return None
+  # Cache the value on the device that accesses the variable.
+  return lambda op: op.device
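For reference, a minimal sketch of the contract behind the returned lambda op: op.device (names and shapes are illustrative, not from this commit): TensorFlow calls the caching-device function with the op that reads the variable, and places the cached copy on the returned device string.

import tensorflow as tf

def cache_on_consumer(op):
  # `op` is the operation reading the variable; returning its device caches the
  # value next to the consumer rather than on the variable's home device.
  return op.device

with tf.compat.v1.variable_scope('example', caching_device=cache_on_consumer):
  w = tf.compat.v1.get_variable('w', shape=[8, 8])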