Make LayerNormBasicLSTMCell compatible with datatypes other than float32 (#12209)
commit 00e097241e
parent db430c4bb2
@@ -1362,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
   def output_size(self):
     return self._num_units
 
-  def _norm(self, inp, scope):
+  def _norm(self, inp, scope, dtype=dtypes.float32):
     shape = inp.get_shape()[-1:]
     gamma_init = init_ops.constant_initializer(self._norm_gain)
     beta_init = init_ops.constant_initializer(self._norm_shift)
     with vs.variable_scope(scope):
       # Initialize beta and gamma for use by layer_norm.
-      vs.get_variable("gamma", shape=shape, initializer=gamma_init)
-      vs.get_variable("beta", shape=shape, initializer=beta_init)
+      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
+      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
     normalized = layers.layer_norm(inp, reuse=True, scope=scope)
     return normalized
 
   def _linear(self, args):
     out_size = 4 * self._num_units
     proj_size = args.get_shape()[-1]
-    weights = vs.get_variable("kernel", [proj_size, out_size])
+    dtype = args.dtype
+    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
     out = math_ops.matmul(args, weights)
     if not self._layer_norm:
-      bias = vs.get_variable("bias", [out_size])
+      bias = vs.get_variable("bias", [out_size], dtype=dtype)
       out = nn_ops.bias_add(out, bias)
     return out
 
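Not part of the commit, but a minimal sketch (assuming TensorFlow 1.x graph mode, with hypothetical names and shapes) of the failure this hunk addresses: tf.get_variable defaults to float32, so a kernel created without an explicit dtype cannot be multiplied with float64 activations. Passing the input's dtype, as the commit does, keeps the variable and its input in agreement.

import tensorflow as tf

# Hypothetical stand-in for the activations flowing into _linear/_norm.
args = tf.placeholder(tf.float64, shape=[None, 8])
with tf.variable_scope("demo"):
  # Without an explicit dtype the variable defaults to float32, and
  # tf.matmul rejects mixed float32/float64 operands when the graph is built.
  kernel_f32 = tf.get_variable("kernel_f32", shape=[8, 4])
  # out = tf.matmul(args, kernel_f32)  # would raise: operands must share a dtype

  # The pattern the commit adopts: create the variable with the input's dtype.
  kernel = tf.get_variable("kernel", shape=[8, 4], dtype=args.dtype)
  out = tf.matmul(args, kernel)  # ok, both operands are float64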
@@ -1388,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
     c, h = state
     args = array_ops.concat([inputs, h], 1)
     concat = self._linear(args)
+    dtype = args.dtype
 
     i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
     if self._layer_norm:
-      i = self._norm(i, "input")
-      j = self._norm(j, "transform")
-      f = self._norm(f, "forget")
-      o = self._norm(o, "output")
+      i = self._norm(i, "input", dtype=dtype)
+      j = self._norm(j, "transform", dtype=dtype)
+      f = self._norm(f, "forget", dtype=dtype)
+      o = self._norm(o, "output", dtype=dtype)
 
     g = self._activation(j)
     if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
@@ -1403,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
     new_c = (c * math_ops.sigmoid(f + self._forget_bias)
              + math_ops.sigmoid(i) * g)
     if self._layer_norm:
-      new_c = self._norm(new_c, "state")
+      new_c = self._norm(new_c, "state", dtype=dtype)
     new_h = self._activation(new_c) * math_ops.sigmoid(o)
 
     new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)