Fix CuDNNCompatibleGRU after GRUCell refactorization

PiperOrigin-RevId: 175574730
This commit is contained in:
James Qin 2017-11-13 12:59:04 -08:00 committed by TensorFlower Gardener
parent 333bdea952
commit 90222dd7b2

View File

@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input"
CUDNN_INPUT_SKIP_MODE = "skip_input"
CUDNN_INPUT_AUTO_MODE = "auto_select"
# pylint:disable=protected-access
_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
# pylint:enable=protected-access
class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
"""Cudnn Compatible LSTMCell.
@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
Cudnn compatible GRU (from Cudnn library user guide):
```python
r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate
i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate
u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate
h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate
h_t = (1 - i_t) .* h'_t + i_t .* h_t-1
h_t = (1 - u_t) .* h'_t + u_t .* h_t-1
```
Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}):
@ -112,33 +117,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
reuse=reuse,
kernel_initializer=kernel_initializer)
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
self._gate_kernel = self.add_variable(
"gates/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, 2 * self._num_units],
initializer=self._kernel_initializer)
self._gate_bias = self.add_variable(
"gates/%s" % _BIAS_VARIABLE_NAME,
shape=[2 * self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.constant_initializer(1.0, dtype=self.dtype)))
self._candidate_input_kernel = self.add_variable(
"candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth, self._num_units],
initializer=self._kernel_initializer)
self._candidate_hidden_kernel = self.add_variable(
"candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[self._num_units, self._num_units],
initializer=self._kernel_initializer)
self._candidate_input_bias = self.add_variable(
"candidate/input_projection/%s" % _BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.zeros_initializer(dtype=self.dtype)))
self._candidate_hidden_bias = self.add_variable(
"candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.zeros_initializer(dtype=self.dtype)))
def call(self, inputs, state):
"""Gated recurrent unit (GRU) with nunits cells."""
with vs.variable_scope("gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
bias_ones = self._bias_initializer
if self._bias_initializer is None:
dtype = inputs.dtype
bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
# pylint: disable=protected-access
value = math_ops.sigmoid(
core_rnn_cell._linear([inputs, state], 2 * self._num_units, True,
bias_ones, self._kernel_initializer))
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
# pylint: enable=protected-access
with vs.variable_scope("candidate"):
# pylint: disable=protected-access
with vs.variable_scope("input_projection"):
hi = core_rnn_cell._linear(inputs, self._num_units, True,
self._bias_initializer,
self._kernel_initializer)
with vs.variable_scope("hidden_projection"):
hh = r * (core_rnn_cell._linear(state, self._num_units, True,
self._bias_initializer,
self._kernel_initializer))
# pylint: enable=protected-access
c = self._activation(hi + hh)
new_h = u * state + (1 - u) * c
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._gate_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
value = math_ops.sigmoid(gate_inputs)
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
candidate = nn_ops.bias_add(
math_ops.matmul(inputs, self._candidate_input_kernel),
self._candidate_input_bias)
candidate += r * nn_ops.bias_add(
math_ops.matmul(state, self._candidate_hidden_kernel),
self._candidate_hidden_bias)
candidate = self._activation(candidate)
new_h = (1-u) * candidate + u * state
return new_h, new_h