Add DeviceWrapper and ResidualWrapper to tf.contrib.rnn.

Change: 145477306
This commit is contained in:
Eugene Brevdo 2017-01-24 15:37:49 -08:00 committed by TensorFlower Gardener
parent f9006d72eb
commit 13c45b266a
4 changed files with 94 additions and 0 deletions

View File

@ -36,6 +36,8 @@
@@EmbeddingWrapper
@@InputProjectionWrapper
@@OutputProjectionWrapper
@@DeviceWrapper
@@ResidualWrapper
### Block RNNCells
@@LSTMBlockCell
@ -76,6 +78,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
@ -84,6 +87,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
# pylint: disable=unused-import,wildcard-import, line-too-long

View File

@ -312,6 +312,36 @@ class RNNCellTest(test.TestCase):
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
def testResidualWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
base_cell = core_rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
variable_scope.get_variable_scope().reuse_variables()
g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.1, 0.1]])
})
# Residual connections
self.assertAllClose(res[1], res[0] + [1., 1., 1.])
# States are left untouched
self.assertAllClose(res[2], res[3])
def testDeviceWrapper(self):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
cell = core_rnn_cell_impl.DeviceWrapper(
core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
outputs, _ = cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower())
def testDropoutWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(

View File

@ -37,6 +37,8 @@
@@EmbeddingWrapper
@@InputProjectionWrapper
@@OutputProjectionWrapper
@@DeviceWrapper
@@ResidualWrapper
"""
from __future__ import absolute_import
from __future__ import division

View File

@ -528,6 +528,60 @@ class DropoutWrapper(RNNCell):
return output, new_state
class ResidualWrapper(RNNCell):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, cell):
"""Constructs a `ResidualWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
"""
self._cell = cell
def __call__(self, inputs, state, scope=None):
"""Run the cell and add its inputs to its outputs.
Args:
inputs: cell inputs.
state: cell state.
scope: optional cell scope.
Returns:
Tuple of cell outputs and new state.
Raises:
TypeError: If cell inputs and outputs have different structure (type).
ValueError: If cell inputs and outputs have different structure (value).
"""
output, new_state = self._cell(inputs, state, scope=scope)
nest.assert_same_structure(inputs, output)
res_output = nest.map_structure(
lambda inp, out: inp + out, inputs, output)
return (res_output, new_state)
class DeviceWrapper(RNNCell):
"""Operator that ensures an RNNCell runs on a particular device."""
def __init__(self, cell, device):
"""Construct a `DeviceWrapper` for `cell` with device `device`.
Ensures the wrapped `cell` is called with `tf.device(device)`.
Args:
cell: An instance of `RNNCell`.
device: A device string or function, for passing to `tf.device`.
"""
self._cell = cell
self._device = device
def __call__(self, inputs, state, scope=None):
"""Run the cell on specified device."""
with ops.device(self._device):
return self._cell(inputs, state, scope=scope)
class EmbeddingWrapper(RNNCell):
"""Operator adding input embedding to the given cell.
@ -615,6 +669,10 @@ class MultiRNNCell(RNNCell):
"""
if not cells:
raise ValueError("Must specify at least one cell for MultiRNNCell.")
if not nest.is_sequence(cells):
raise TypeError(
"cells must be a list or tuple, but saw: %s." % cells)
self._cells = cells
self._state_is_tuple = state_is_tuple
if not state_is_tuple: