Add DeviceWrapper and ResidualWrapper to tf.contrib.rnn.
Change: 145477306
This commit is contained in:
parent
f9006d72eb
commit
13c45b266a
@ -36,6 +36,8 @@
|
||||
@@EmbeddingWrapper
|
||||
@@InputProjectionWrapper
|
||||
@@OutputProjectionWrapper
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
|
||||
### Block RNNCells
|
||||
@@LSTMBlockCell
|
||||
@ -76,6 +78,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn
|
||||
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
|
||||
@ -84,6 +87,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||
|
@ -312,6 +312,36 @@ class RNNCellTest(test.TestCase):
|
||||
# The numbers in results were not calculated, this is just a smoke test.
|
||||
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
||||
|
||||
def testResidualWrapper(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
base_cell = core_rnn_cell_impl.GRUCell(3)
|
||||
g, m_new = base_cell(x, m)
|
||||
variable_scope.get_variable_scope().reuse_variables()
|
||||
g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
|
||||
sess.run([variables_lib.global_variables_initializer()])
|
||||
res = sess.run([g, g_res, m_new, m_new_res], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
m: np.array([[0.1, 0.1, 0.1]])
|
||||
})
|
||||
# Residual connections
|
||||
self.assertAllClose(res[1], res[0] + [1., 1., 1.])
|
||||
# States are left untouched
|
||||
self.assertAllClose(res[2], res[3])
|
||||
|
||||
def testDeviceWrapper(self):
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 3])
|
||||
cell = core_rnn_cell_impl.DeviceWrapper(
|
||||
core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
|
||||
outputs, _ = cell(x, m)
|
||||
self.assertTrue("cpu:14159" in outputs.device.lower())
|
||||
|
||||
def testDropoutWrapper(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope(
|
||||
|
@ -37,6 +37,8 @@
|
||||
@@EmbeddingWrapper
|
||||
@@InputProjectionWrapper
|
||||
@@OutputProjectionWrapper
|
||||
@@DeviceWrapper
|
||||
@@ResidualWrapper
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -528,6 +528,60 @@ class DropoutWrapper(RNNCell):
|
||||
return output, new_state
|
||||
|
||||
|
||||
class ResidualWrapper(RNNCell):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
    """
    self._cell = cell

  @property
  def state_size(self):
    # Delegate to the wrapped cell; without this the abstract base class
    # property would be used, breaking callers (e.g. dynamic_rnn) that
    # query the wrapper's state size.
    return self._cell.state_size

  @property
  def output_size(self):
    # The residual connection is element-wise, so the output size is
    # exactly the wrapped cell's output size.
    return self._cell.output_size

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs (with the residual connection applied) and the
      new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    output, new_state = self._cell(inputs, state, scope=scope)
    # The element-wise residual only makes sense when inputs and outputs
    # line up structurally; fail fast otherwise.
    nest.assert_same_structure(inputs, output)
    res_output = nest.map_structure(
        lambda inp, out: inp + out, inputs, output)
    return (res_output, new_state)
|
||||
|
||||
|
||||
class DeviceWrapper(RNNCell):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
    """
    self._cell = cell
    self._device = device

  @property
  def state_size(self):
    # Delegate to the wrapped cell; without this the abstract base class
    # property would be used, breaking callers (e.g. dynamic_rnn) that
    # query the wrapper's state size.
    return self._cell.state_size

  @property
  def output_size(self):
    # Device placement does not change the cell's output shape.
    return self._cell.output_size

  def __call__(self, inputs, state, scope=None):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return self._cell(inputs, state, scope=scope)
|
||||
|
||||
|
||||
class EmbeddingWrapper(RNNCell):
|
||||
"""Operator adding input embedding to the given cell.
|
||||
|
||||
@ -615,6 +669,10 @@ class MultiRNNCell(RNNCell):
|
||||
"""
|
||||
if not cells:
|
||||
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
||||
if not nest.is_sequence(cells):
|
||||
raise TypeError(
|
||||
"cells must be a list or tuple, but saw: %s." % cells)
|
||||
|
||||
self._cells = cells
|
||||
self._state_is_tuple = state_is_tuple
|
||||
if not state_is_tuple:
|
||||
|
Loading…
x
Reference in New Issue
Block a user