From 90222dd7b29ff2597bc7f8d0f92db17324f591b0 Mon Sep 17 00:00:00 2001
From: James Qin <jamesqin@google.com>
Date: Mon, 13 Nov 2017 12:59:04 -0800
Subject: [PATCH] Fix CuDNNCompatibleGRU after GRUCell refactorization

PiperOrigin-RevId: 175574730
---
 .../cudnn_rnn/python/ops/cudnn_rnn_ops.py     | 93 +++++++++++++------
 1 file changed, 65 insertions(+), 28 deletions(-)

diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 9f748996934..6c526b2c756 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
 from tensorflow.contrib.rnn.python.ops import lstm_ops
 from tensorflow.contrib.util import loader
 from tensorflow.python.framework import common_shapes
@@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope as vs
@@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input"
 CUDNN_INPUT_SKIP_MODE = "skip_input"
 CUDNN_INPUT_AUTO_MODE = "auto_select"
 
+# pylint:disable=protected-access
+_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
+_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
+# pylint:enable=protected-access
+
 
 class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
   """Cudnn Compatible LSTMCell.
@@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
   Cudnn compatible GRU (from Cudnn library user guide):
   ```python
   r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr)  # reset gate
-  i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru)  # update gate
+  u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru)  # update gate
   h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh)  # new memory gate
-  h_t = (1 - i_t) .* h'_t + i_t .* h_t-1
+  h_t = (1 - u_t) .* h'_t + u_t .* h_t-1
   ```
 
   Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}):
@@ -112,33 +117,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
         reuse=reuse,
         kernel_initializer=kernel_initializer)
 
+  def build(self, inputs_shape):
+    if inputs_shape[1].value is None:
+      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+                       % inputs_shape)
+
+    input_depth = inputs_shape[1].value
+    self._gate_kernel = self.add_variable(
+        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth + self._num_units, 2 * self._num_units],
+        initializer=self._kernel_initializer)
+    self._gate_bias = self.add_variable(
+        "gates/%s" % _BIAS_VARIABLE_NAME,
+        shape=[2 * self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.constant_initializer(1.0, dtype=self.dtype)))
+
+    self._candidate_input_kernel = self.add_variable(
+        "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth, self._num_units],
+        initializer=self._kernel_initializer)
+    self._candidate_hidden_kernel = self.add_variable(
+        "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[self._num_units, self._num_units],
+        initializer=self._kernel_initializer)
+
+    self._candidate_input_bias = self.add_variable(
+        "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+    self._candidate_hidden_bias = self.add_variable(
+        "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+
   def call(self, inputs, state):
     """Gated recurrent unit (GRU) with nunits cells."""
-    with vs.variable_scope("gates"):  # Reset gate and update gate.
-      # We start with bias of 1.0 to not reset and not update.
-      bias_ones = self._bias_initializer
-      if self._bias_initializer is None:
-        dtype = inputs.dtype
-        bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
-      # pylint: disable=protected-access
-      value = math_ops.sigmoid(
-          core_rnn_cell._linear([inputs, state], 2 * self._num_units, True,
-                                bias_ones, self._kernel_initializer))
-      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
-      # pylint: enable=protected-access
-    with vs.variable_scope("candidate"):
-      # pylint: disable=protected-access
-      with vs.variable_scope("input_projection"):
-        hi = core_rnn_cell._linear(inputs, self._num_units, True,
-                                   self._bias_initializer,
-                                   self._kernel_initializer)
-      with vs.variable_scope("hidden_projection"):
-        hh = r * (core_rnn_cell._linear(state, self._num_units, True,
-                                        self._bias_initializer,
-                                        self._kernel_initializer))
-      # pylint: enable=protected-access
-      c = self._activation(hi + hh)
-    new_h = u * state + (1 - u) * c
+    gate_inputs = math_ops.matmul(
+        array_ops.concat([inputs, state], 1), self._gate_kernel)
+    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+    value = math_ops.sigmoid(gate_inputs)
+    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+    candidate = nn_ops.bias_add(
+        math_ops.matmul(inputs, self._candidate_input_kernel),
+        self._candidate_input_bias)
+    candidate += r * nn_ops.bias_add(
+        math_ops.matmul(state, self._candidate_hidden_kernel),
+        self._candidate_hidden_bias)
+    candidate = self._activation(candidate)
+    new_h = (1-u) * candidate + u * state
     return new_h, new_h