diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 9f748996934..6c526b2c756 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input" CUDNN_INPUT_SKIP_MODE = "skip_input" CUDNN_INPUT_AUTO_MODE = "auto_select" +# pylint:disable=protected-access +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +# pylint:enable=protected-access + class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. @@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate + u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - i_t) .* h'_t + i_t .* h_t-1 + h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): @@ -112,33 +117,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): reuse=reuse, kernel_initializer=kernel_initializer) + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" + % inputs_shape) + + input_depth = inputs_shape[1].value + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.constant_initializer(1.0, dtype=self.dtype))) + + self._candidate_input_kernel = self.add_variable( + "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_hidden_kernel = self.add_variable( + "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_units], + initializer=self._kernel_initializer) + + self._candidate_input_bias = self.add_variable( + "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + self._candidate_hidden_bias = self.add_variable( + "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. - bias_ones = self._bias_initializer - if self._bias_initializer is None: - dtype = inputs.dtype - bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) - # pylint: disable=protected-access - value = math_ops.sigmoid( - core_rnn_cell._linear([inputs, state], 2 * self._num_units, True, - bias_ones, self._kernel_initializer)) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - # pylint: enable=protected-access - with vs.variable_scope("candidate"): - # pylint: disable=protected-access - with vs.variable_scope("input_projection"): - hi = core_rnn_cell._linear(inputs, self._num_units, True, - self._bias_initializer, - self._kernel_initializer) - with vs.variable_scope("hidden_projection"): - hh = r * (core_rnn_cell._linear(state, self._num_units, True, - self._bias_initializer, - self._kernel_initializer)) - # pylint: enable=protected-access - c = self._activation(hi + hh) - new_h = u * state + (1 - u) * c + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + candidate = nn_ops.bias_add( + math_ops.matmul(inputs, self._candidate_input_kernel), + self._candidate_input_bias) + candidate += r * nn_ops.bias_add( + math_ops.matmul(state, self._candidate_hidden_kernel), + self._candidate_hidden_bias) + candidate = self._activation(candidate) + new_h = (1-u) * candidate + u * state return new_h, new_h