Update model save/load for unified_lstm layer.
The bias weights during load is changed a bit. The original approach is to split the weights into half and give them to recurrent bias and input gate bias. In the new approach, since it is unknown that whether the newly constructed layer will be running on CPU or GPU, if we still split the weight into half, then the input_gate_bias value is lost since canonical LSTM only use recurrent bias. The new approach is to give recurrent_bias 100% of the sum, and 0% to input_gate_bias. PiperOrigin-RevId: 224360388
This commit is contained in:
parent
322254f7e0
commit
4f6613441e
@ -551,10 +551,68 @@ def preprocess_weights_for_loading(layer,
|
||||
if layer.__class__.__name__ == 'ConvLSTM2D':
|
||||
weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
|
||||
|
||||
weights = _convert_unified_lstm_weights(layer, weights)
|
||||
|
||||
# convert CuDNN layers
|
||||
return _convert_rnn_weights(layer, weights)
|
||||
|
||||
|
||||
def _convert_unified_lstm_weights(layer, weights):
|
||||
"""Converts weights for Unified LSTM layer.
|
||||
|
||||
The input weights suppose to have 2, 3 or 4 items.
|
||||
1. kernel. (i, f, c, o gates concat among axis 1)
|
||||
2. recurrent_kernel. (i, f, c, o concat among axis 1)
|
||||
3. recurrent_bias. (optional, only available when use bias)
|
||||
4. input_bias (optional, only available when use bias and cudnn).
|
||||
Kernel and recurrent_kernel does not need any conversion. During load(),
|
||||
since the layer could be built with the parameter that does not support the
|
||||
defun approach, it is possible that cudnn_bias variable is not created, or
|
||||
even created but not used during actual run. Because of that, we sum up the
|
||||
value of two biases, and give it to recurrent_bias only. Mathematically, the
|
||||
LSTM is calculated as following formula:
|
||||
|
||||
i_t = sigmoid(w_i * x_t + r_i * h_(t-1) + b_wi + b_ri)
|
||||
f_t = sigmoid(w_f * x_t + r_f * h_(t-1) + b_wf + b_rf)
|
||||
o_t = sigmoid(w_o * x_t + r_o * h_(t-1) + b_wo + b_ro)
|
||||
c'_t = tanh(w_c * x_t + r_c * h_(t-1) + b_wc + b_rc)
|
||||
c_t = f_t . c_(t-1) + i_t . c'_t
|
||||
h_t = o_t . tanh(c_t)
|
||||
|
||||
Note that b_w{x} is the input_bias, and b_r{x} is the recurrent_bias.
|
||||
Since it is a linear add, it is fine to give b_r{x} 100% and b_w{x} 0%, as
|
||||
long as the sum are the same.
|
||||
|
||||
Args:
|
||||
layer: The keras layer that will be loaded with weights.
|
||||
weights: the list of numpy arrays which hold the weights to be loaded.
|
||||
|
||||
Returns:
|
||||
weights: the processed list of numpy arrays.
|
||||
"""
|
||||
if layer.__class__.__name__ == 'UnifiedLSTM':
|
||||
if len(weights) not in [3, 4]:
|
||||
# Only handles the bias conversion in this function, in the case when
|
||||
# bias is not used or weights in unexpected length, do nothing and return.
|
||||
return weights
|
||||
|
||||
if len(weights) == 3:
|
||||
recurrent_bias = weights[2]
|
||||
else:
|
||||
# Add all the bias value to recurrent_bias
|
||||
recurrent_bias = weights[2] + weights[3]
|
||||
|
||||
if len(layer.weights) == 3:
|
||||
weights = weights[:2] + [recurrent_bias]
|
||||
elif len(layer.weights) == 4:
|
||||
# Create a zero filled input_bias, since all the weights have given
|
||||
# to recurrent bias.
|
||||
input_bias = np.zeros_like(recurrent_bias)
|
||||
weights = weights[:2] + [recurrent_bias, input_bias]
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def _convert_rnn_weights(layer, weights):
|
||||
"""Converts weights for RNN layers between native and CuDNN format.
|
||||
|
||||
|
@ -221,6 +221,70 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
for (x, y) in zip(weights1, weights2)
|
||||
]
|
||||
|
||||
@parameterized.named_parameters(
|
||||
# test_name, use_bias, bias_initializer, activation
|
||||
('normal', True, 'zeros', 'tanh'),
|
||||
('no_bias', False, 'zeros', 'tanh'),
|
||||
# TODO(scottzhu): Reenable this test case when the approach is decided.
|
||||
# ('random_bias', True, 'random_uniform', 'tanh'),
|
||||
('no_cudnn_bias', True, 'zeros', 'relu')
|
||||
)
|
||||
def test_process_weights_for_loading_unified_lstm(
|
||||
self, use_bias, bias_initializer, activation):
|
||||
if h5py is None:
|
||||
return
|
||||
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
h5_path = os.path.join(temp_dir, 'test.h5')
|
||||
|
||||
batch = 10
|
||||
timestep = 3
|
||||
input_dim = 5
|
||||
units = 2
|
||||
|
||||
x = np.random.random((batch, timestep, input_dim))
|
||||
|
||||
def build_model():
|
||||
inputs = keras.layers.Input(
|
||||
shape=[timestep, input_dim], dtype=dtypes.float32)
|
||||
layer = keras.layers.UnifiedLSTM(
|
||||
units,
|
||||
activation=activation,
|
||||
use_bias=use_bias,
|
||||
bias_initializer=bias_initializer)
|
||||
output = layer(inputs)
|
||||
return keras.models.Model(inputs, output), layer
|
||||
|
||||
with self.cached_session():
|
||||
model, layer = build_model()
|
||||
y_ref = model.predict(x)
|
||||
model.save_weights(h5_path)
|
||||
|
||||
cloned_model, new_layer = build_model()
|
||||
cloned_model.load_weights(h5_path)
|
||||
y = cloned_model.predict(x)
|
||||
|
||||
self.assertAllClose(y, y_ref)
|
||||
|
||||
# Test the individual layer weights.
|
||||
weights1 = layer.get_weights()
|
||||
weights2 = new_layer.get_weights()
|
||||
self.assertLen(weights1, len(weights2))
|
||||
# kernel and current kernel should be the same.
|
||||
self.assertAllClose(weights1[:2], weights2[:2])
|
||||
|
||||
if len(weights2) >= 3:
|
||||
# Test recurrent bias
|
||||
expected_recurrent_bias = weights1[2]
|
||||
if len(weights1) == 4:
|
||||
expected_recurrent_bias += weights1[3]
|
||||
self.assertAllClose(weights2[2], expected_recurrent_bias)
|
||||
|
||||
if len(weights2) == 4:
|
||||
# Test recovered input_gate_bias to be always zero
|
||||
self.assertAllClose(weights2[3], np.zeros_like(weights1[3]))
|
||||
|
||||
def test_sequential_weight_loading(self):
|
||||
if h5py is None:
|
||||
return
|
||||
|
@ -149,6 +149,7 @@ from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
|
||||
from tensorflow.python.keras.layers.recurrent import SimpleRNN
|
||||
from tensorflow.python.keras.layers.recurrent import GRU
|
||||
from tensorflow.python.keras.layers.recurrent import LSTM
|
||||
from tensorflow.python.keras.layers.recurrent import UnifiedLSTM
|
||||
|
||||
# Convolutional-recurrent layers.
|
||||
from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
|
||||
|
Loading…
Reference in New Issue
Block a user