Make LayerNormBasicLSTMCell compatible with datatypes other than float32 (#12209)

Maximilian Bachl 2017-11-07 18:55:28 +01:00 committed by Martin Wicke
parent db430c4bb2
commit 00e097241e


@@ -1362,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
   def output_size(self):
     return self._num_units
 
-  def _norm(self, inp, scope):
+  def _norm(self, inp, scope, dtype=dtypes.float32):
     shape = inp.get_shape()[-1:]
     gamma_init = init_ops.constant_initializer(self._norm_gain)
     beta_init = init_ops.constant_initializer(self._norm_shift)
     with vs.variable_scope(scope):
       # Initialize beta and gamma for use by layer_norm.
-      vs.get_variable("gamma", shape=shape, initializer=gamma_init)
-      vs.get_variable("beta", shape=shape, initializer=beta_init)
+      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
+      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
     normalized = layers.layer_norm(inp, reuse=True, scope=scope)
     return normalized
 
   def _linear(self, args):
     out_size = 4 * self._num_units
     proj_size = args.get_shape()[-1]
-    weights = vs.get_variable("kernel", [proj_size, out_size])
+    dtype = args.dtype
+    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
     out = math_ops.matmul(args, weights)
     if not self._layer_norm:
-      bias = vs.get_variable("bias", [out_size])
+      bias = vs.get_variable("bias", [out_size], dtype=dtype)
       out = nn_ops.bias_add(out, bias)
     return out
 
@@ -1388,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
     c, h = state
     args = array_ops.concat([inputs, h], 1)
     concat = self._linear(args)
+    dtype = args.dtype
 
     i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
     if self._layer_norm:
-      i = self._norm(i, "input")
-      j = self._norm(j, "transform")
-      f = self._norm(f, "forget")
-      o = self._norm(o, "output")
+      i = self._norm(i, "input", dtype=dtype)
+      j = self._norm(j, "transform", dtype=dtype)
+      f = self._norm(f, "forget", dtype=dtype)
+      o = self._norm(o, "output", dtype=dtype)
 
     g = self._activation(j)
     if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
@@ -1403,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
     new_c = (c * math_ops.sigmoid(f + self._forget_bias)
              + math_ops.sigmoid(i) * g)
     if self._layer_norm:
-      new_c = self._norm(new_c, "state")
+      new_c = self._norm(new_c, "state", dtype=dtype)
     new_h = self._activation(new_c) * math_ops.sigmoid(o)
     new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
     return new_h, new_state
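
Below is a minimal usage sketch of what this change enables (not part of the commit; the shapes and feed values are illustrative, and it assumes the TF 1.x contrib API of this era). It runs the cell on float64 inputs, which previously failed because the gamma, beta, kernel, and bias variables were always created as float32.

# Hypothetical usage sketch (not part of this commit): exercising the cell
# with float64 inputs under the TF 1.x API.
import numpy as np
import tensorflow as tf

batch_size, input_size, num_units = 2, 3, 4

# Before this change, the variables created inside the cell were always
# float32, so feeding float64 inputs raised a dtype mismatch.
inputs = tf.placeholder(tf.float64, [batch_size, input_size])
cell = tf.contrib.rnn.LayerNormBasicLSTMCell(num_units)
state = cell.zero_state(batch_size, dtype=tf.float64)
output, new_state = cell(inputs, state)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  result = sess.run(output, {inputs: np.zeros((batch_size, input_size))})
  print(result.dtype)  # float64

Note that _norm keeps dtype=dtypes.float32 as its default, so callers that never pass a dtype see unchanged behavior; inside call, the dtype is taken from the concatenated inputs and propagated to every variable the cell creates.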