Fix CuDNNCompatibleGRU after GRUCell refactorization
PiperOrigin-RevId: 175574730
This commit is contained in:
parent
333bdea952
commit
90222dd7b2
@ -18,7 +18,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
|
||||
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
|
||||
from tensorflow.contrib.rnn.python.ops import lstm_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import common_shapes
|
||||
@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input"
|
||||
CUDNN_INPUT_SKIP_MODE = "skip_input"
|
||||
CUDNN_INPUT_AUTO_MODE = "auto_select"
|
||||
|
||||
# pylint:disable=protected-access
|
||||
_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
|
||||
_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
|
||||
# pylint:enable=protected-access
|
||||
|
||||
|
||||
class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
|
||||
"""Cudnn Compatible LSTMCell.
|
||||
@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
|
||||
Cudnn compatible GRU (from Cudnn library user guide):
|
||||
```python
|
||||
r_t = sigma(x_t * W_r + h_t-1 * R_r + b_Wr + b_Rr) # reset gate
|
||||
i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate
|
||||
u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate
|
||||
h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate
|
||||
h_t = (1 - i_t) .* h'_t + i_t .* h_t-1
|
||||
h_t = (1 - u_t) .* h'_t + u_t .* h_t-1
|
||||
```
|
||||
|
||||
Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}):
|
||||
@ -112,33 +117,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
|
||||
reuse=reuse,
|
||||
kernel_initializer=kernel_initializer)
|
||||
|
||||
def build(self, inputs_shape):
  """Creates the gate and candidate variables used by `call`.

  Six variables are registered through `add_variable`, in this order:
  the gate kernel and bias, the candidate input/hidden projection kernels,
  and the candidate input/hidden projection biases.

  Args:
    inputs_shape: Shape of the cell inputs. The entry at index 1 is read as
      the input depth and must be statically known.

  Raises:
    ValueError: If `inputs_shape[1].value` is None.
  """
  depth = inputs_shape[1].value
  if depth is None:
    raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                     % inputs_shape)
  units = self._num_units

  # One kernel/bias pair produces both gates, hence the 2 * units columns.
  self._gate_kernel = self.add_variable(
      "gates/%s" % _WEIGHTS_VARIABLE_NAME,
      shape=[depth + units, 2 * units],
      initializer=self._kernel_initializer)
  gate_bias_init = self._bias_initializer
  if gate_bias_init is None:
    # Without a user-supplied initializer the gate bias starts at 1.0.
    gate_bias_init = init_ops.constant_initializer(1.0, dtype=self.dtype)
  self._gate_bias = self.add_variable(
      "gates/%s" % _BIAS_VARIABLE_NAME,
      shape=[2 * units],
      initializer=gate_bias_init)

  # The candidate keeps separate projections for the inputs and the state.
  self._candidate_input_kernel = self.add_variable(
      "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME,
      shape=[depth, units],
      initializer=self._kernel_initializer)
  self._candidate_hidden_kernel = self.add_variable(
      "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME,
      shape=[units, units],
      initializer=self._kernel_initializer)

  # Candidate biases fall back to zeros when no bias initializer was given.
  input_bias_init = self._bias_initializer
  if input_bias_init is None:
    input_bias_init = init_ops.zeros_initializer(dtype=self.dtype)
  self._candidate_input_bias = self.add_variable(
      "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME,
      shape=[units],
      initializer=input_bias_init)
  hidden_bias_init = self._bias_initializer
  if hidden_bias_init is None:
    hidden_bias_init = init_ops.zeros_initializer(dtype=self.dtype)
  self._candidate_hidden_bias = self.add_variable(
      "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME,
      shape=[units],
      initializer=hidden_bias_init)
|
||||
|
||||
def call(self, inputs, state):
  """Gated recurrent unit (GRU) with nunits cells.

  Runs one step using the variables created in `build`: the gate kernel and
  bias, and the separate input/hidden candidate projections.

  Args:
    inputs: Input tensor for this step; concatenated with `state` along
      axis 1 (assumed 2-D `[batch, input_depth]` -- TODO confirm).
    state: Previous hidden state (assumed `[batch, num_units]` -- TODO
      confirm).

  Returns:
    A pair `(new_h, new_h)`: the new hidden state, returned both as the
    output and as the next state.
  """
  # Reset gate and update gate, computed together from [inputs, state].
  gate_inputs = math_ops.matmul(
      array_ops.concat([inputs, state], 1), self._gate_kernel)
  gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

  value = math_ops.sigmoid(gate_inputs)
  r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

  # The reset gate scales the hidden projection *after* its bias is added,
  # i.e. r * (state * R_h + b_Rh), while the input projection is left
  # unscaled.
  candidate = nn_ops.bias_add(
      math_ops.matmul(inputs, self._candidate_input_kernel),
      self._candidate_input_bias)
  candidate += r * nn_ops.bias_add(
      math_ops.matmul(state, self._candidate_hidden_kernel),
      self._candidate_hidden_bias)
  candidate = self._activation(candidate)

  # The update gate interpolates between the candidate and the old state.
  new_h = (1 - u) * candidate + u * state
  return new_h, new_h
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user