Keras RNN performance fix for RNNEstimator with multi-worker training.

1. Set the default variable caching device to the device that consumes the variable. This is the key part: it avoids repeated reads of the variable inside the tf.while_loop, which heavily degrade performance in PS and remote-worker settings. (A minimal sketch of the mechanism follows this list.)

2. Update the RNN estimator to use the v2 Keras implementation for the cell, which uses bulk matmuls rather than slicing the weights.

3. Minor change to the GRU cell implementation: replace array-index reads with tf.split.
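
As a rough illustration of item 1 (not part of the commit; the device names and the `kernel` variable below are purely illustrative), the following TF1-style sketch shows the caching-device mechanism: the variable lives on a parameter server, but because `caching_device` resolves to the device of the op reading it, the worker caches the value locally instead of fetching it from the PS on every read.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
  with tf.device('/job:ps/task:0'):
    # The variable itself is placed on the parameter server.
    kernel = tf.get_variable(
        'kernel', shape=[4, 4],
        caching_device=lambda op: op.device)  # cache on the device that reads it
  with tf.device('/job:worker/task:0'):
    # This matmul reads `kernel`; the cached copy lands on the worker, so
    # repeated reads (e.g. inside a tf.while_loop) do not go back to the PS.
    out = tf.matmul(tf.ones([2, 4]), kernel)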

PiperOrigin-RevId: 272926308
Scott Zhu 2019-10-04 11:52:41 -07:00 committed by TensorFlower Gardener
parent 45793ed5e2
commit fe1178ab88
2 changed files with 87 additions and 19 deletions

tensorflow/python/keras/engine/base_layer.py

@@ -344,8 +344,8 @@ class Layer(module.Module):
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
`tf.VariableAggregation`.
**kwargs: Additional keyword arguments. Accepted values are `getter` and
`collections`.
**kwargs: Additional keyword arguments. Accepted values are `getter`,
`collections`, `experimental_autocast` and `caching_device`.
Returns:
The created variable. Usually either a `Variable` or `ResourceVariable`
@@ -362,13 +362,16 @@ class Layer(module.Module):
shape = ()
# Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in ['getter', 'collections', 'experimental_autocast']:
if kwarg not in ['getter', 'collections', 'experimental_autocast',
'caching_device']:
raise TypeError('Unknown keyword argument:', kwarg)
getter = kwargs.pop('getter', base_layer_utils.make_variable)
collections_arg = kwargs.pop('collections', None)
# 'experimental_autocast' can be set to False by the caller to indicate an
# AutoCastVariable should never be created.
autocast = kwargs.pop('experimental_autocast', True)
# See the docstring for tf.Variable about the details for caching_device.
caching_device = kwargs.pop('caching_device', None)
if dtype is None:
dtype = self.dtype or backend.floatx()
@@ -414,6 +417,13 @@ class Layer(module.Module):
def getter(*args, **kwargs): # pylint: disable=function-redefined
variable = old_getter(*args, **kwargs)
return autocast_variable.create_autocast_variable(variable)
# The caching_device also does not work with the mixed precision API;
# disable it if it is specified.
# TODO(b/142020079): Reenable it once the bug is fixed.
if caching_device is not None:
tf_logging.warn('`caching_device` does not work with mixed precision '
'API. Ignoring user specified `caching_device`.')
caching_device = None
variable = self._add_variable_with_custom_getter(
name=name,
@@ -431,7 +441,8 @@ class Layer(module.Module):
use_resource=use_resource,
collections=collections_arg,
synchronization=synchronization,
aggregation=aggregation)
aggregation=aggregation,
caching_device=caching_device)
backend.track_variable(variable)
if regularizer is not None:

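For context, here is a hedged sketch (everything in it is illustrative and not from this commit) of a custom layer opting into the `caching_device` keyword argument that `add_weight` now accepts, mirroring what the RNN cells in the next file do:

import tensorflow as tf

class CachedDense(tf.keras.layers.Layer):
  """Toy layer whose kernel reads are cached on the consuming device."""

  def __init__(self, units, **kwargs):
    super(CachedDense, self).__init__(**kwargs)
    self.units = units

  def build(self, input_shape):
    caching_device = None
    if not tf.executing_eagerly():
      # Mirror the RNN cells: cache the variable value on whichever device
      # ends up reading it (caching is not supported in eager mode).
      caching_device = lambda op: op.device
    self.kernel = self.add_weight(
        name='kernel',
        shape=(int(input_shape[-1]), self.units),
        initializer='glorot_uniform',
        caching_device=caching_device)
    super(CachedDense, self).build(input_shape)

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)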
tensorflow/python/keras/layers/recurrent.py

@@ -24,6 +24,7 @@ import collections
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
@@ -36,6 +37,7 @@ from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops.ragged import ragged_tensor
@@ -1307,25 +1309,29 @@ class SimpleRNNCell(DropoutRNNCellMixin, Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
@@ -1737,18 +1743,21 @@ class GRUCell(DropoutRNNCellMixin, Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 3),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
if not self.reset_after:
@@ -1763,7 +1772,8 @@ class GRUCell(DropoutRNNCellMixin, Layer):
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
@@ -1841,9 +1851,7 @@ class GRUCell(DropoutRNNCellMixin, Layer):
# biases: bias_z_i, bias_r_i, bias_h_i
matrix_x = K.bias_add(matrix_x, input_bias)
x_z = matrix_x[:, :self.units]
x_r = matrix_x[:, self.units: 2 * self.units]
x_h = matrix_x[:, 2 * self.units:]
x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
if 0. < self.recurrent_dropout < 1.:
h_tm1 = h_tm1 * rec_dp_mask[0]
@@ -1857,14 +1865,14 @@ class GRUCell(DropoutRNNCellMixin, Layer):
# hidden state projected separately for update/reset and new
matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
recurrent_z = matrix_inner[:, :self.units]
recurrent_r = matrix_inner[:, self.units:2 * self.units]
recurrent_z, recurrent_r, recurrent_h = array_ops.split(
matrix_inner, [self.units, self.units, -1], axis=-1)
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
if self.reset_after:
recurrent_h = r * matrix_inner[:, 2 * self.units:]
recurrent_h = r * recurrent_h
else:
recurrent_h = K.dot(r * h_tm1,
self.recurrent_kernel[:, 2 * self.units:])
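
As a standalone sanity check (not part of the commit), the snippet below shows that the split-based decomposition produces the same gate blocks as the old column slicing, including the uneven [units, units, -1] form used for matrix_inner above:

import numpy as np
import tensorflow as tf

units = 4
matrix_x = tf.constant(np.random.rand(2, 3 * units), dtype=tf.float32)

# Old style: three separate strided-slice reads.
x_z_old = matrix_x[:, :units]
x_r_old = matrix_x[:, units:2 * units]
x_h_old = matrix_x[:, 2 * units:]

# New style: a single split op yields all three gate blocks at once.
x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)

# Uneven split: the trailing -1 absorbs the remainder, matching the
# recurrent_z / recurrent_r / recurrent_h decomposition of matrix_inner.
r_z, r_r, r_h = tf.split(matrix_x, [units, units, -1], axis=-1)

assert np.allclose(x_z.numpy(), x_z_old.numpy())
assert np.allclose(x_r.numpy(), x_r_old.numpy())
assert np.allclose(x_h.numpy(), r_h.numpy())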
@@ -2292,19 +2300,22 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
default_caching_device = _caching_device(self)
input_dim = input_shape[-1]
self.kernel = self.add_weight(
shape=(input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
if self.unit_forget_bias:
@@ -2322,7 +2333,8 @@ class LSTMCell(DropoutRNNCellMixin, Layer):
name='bias',
initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
@@ -2911,3 +2923,48 @@ def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
return nest.map_structure(create_zeros, state_size)
else:
return create_zeros(state_size)
def _caching_device(rnn_cell):
"""Returns the caching device for the RNN variable.
This is useful for distributed training, when variable is not located as same
device as the training worker. By enabling the device cache, this allows
worker to read the variable once and cache locally, rather than read it every
time step from remote when it is needed.
Note that this is assuming the variable that cell needs for each time step is
having the same value in the forward path, and only gets updated in the
backprop. It is true for all the default cells (SimpleRNN, GRU, LSTM). If the
cell body relies on any variable that gets updated every time step, then
caching device will cause it to read the stall value.
Args:
rnn_cell: the rnn cell instance.
"""
if context.executing_eagerly():
# caching_device is not supported in eager mode.
return None
# Don't set a caching device when running in a loop, since it is possible that
# train steps could be wrapped in a tf.while_loop. In that scenario caching
# prevents forward computations in loop iterations from re-reading the
# updated weights.
if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
logging.warn('Variable read device caching has been disabled because the '
'RNN is inside a tf.while_loop context, which would cause it to '
'read stale values in the forward path. This could slow down '
'training due to duplicated variable reads. Please consider '
'updating your code to remove the tf.while_loop if possible.')
return None
if rnn_cell._dtype_policy.should_cast_variables:
logging.warn('Variable read device caching has been disabled since it '
'does not work with the mixed precision API. This is likely to '
'cause a slowdown for RNN training due to duplicated variable '
'reads at each timestep, which will be significant in a multi '
'remote worker setting. Please consider disabling the mixed '
'precision API if performance is affected.')
return None
# Cache the value on the device that accesses the variable.
return lambda op: op.device
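
Finally, a hedged usage sketch (not from the commit; shapes and sizes are arbitrary) of how the helper above comes into play: building a GRUCell-backed RNN in graph mode is unchanged at the API level, but the weights created in build() now carry the default caching device.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
  cell = tf.keras.layers.GRUCell(8)
  layer = tf.keras.layers.RNN(cell)
  outputs = layer(tf.placeholder(tf.float32, shape=[None, 5, 3]))
  # GRUCell.build() ran above; since this is graph mode and not inside an
  # outer tf.while_loop, _caching_device returned `lambda op: op.device`, so
  # each per-step read of kernel/recurrent_kernel/bias inside the RNN's
  # while_loop is served from a cache on the device that consumes it.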