Add DeviceWrapper and ResidualWrapper to tf.contrib.rnn.
Change: 145477306
This commit is contained in:
parent
f9006d72eb
commit
13c45b266a
@ -36,6 +36,8 @@
|
|||||||
@@EmbeddingWrapper
|
@@EmbeddingWrapper
|
||||||
@@InputProjectionWrapper
|
@@InputProjectionWrapper
|
||||||
@@OutputProjectionWrapper
|
@@OutputProjectionWrapper
|
||||||
|
@@DeviceWrapper
|
||||||
|
@@ResidualWrapper
|
||||||
|
|
||||||
### Block RNNCells
|
### Block RNNCells
|
||||||
@@LSTMBlockCell
|
@@LSTMBlockCell
|
||||||
@ -76,6 +78,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn import static_state_saving_rnn
|
|||||||
|
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicLSTMCell
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
|
||||||
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DeviceWrapper
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import DropoutWrapper
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import EmbeddingWrapper
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import GRUCell
|
||||||
@ -84,6 +87,7 @@ from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMCell
|
|||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
|
||||||
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import ResidualWrapper
|
||||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell
|
||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||||
|
@ -312,6 +312,36 @@ class RNNCellTest(test.TestCase):
|
|||||||
# The numbers in results were not calculated, this is just a smoke test.
|
# The numbers in results were not calculated, this is just a smoke test.
|
||||||
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
|
||||||
|
|
||||||
|
def testResidualWrapper(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
with variable_scope.variable_scope(
|
||||||
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
|
x = array_ops.zeros([1, 3])
|
||||||
|
m = array_ops.zeros([1, 3])
|
||||||
|
base_cell = core_rnn_cell_impl.GRUCell(3)
|
||||||
|
g, m_new = base_cell(x, m)
|
||||||
|
variable_scope.get_variable_scope().reuse_variables()
|
||||||
|
g_res, m_new_res = core_rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
|
||||||
|
sess.run([variables_lib.global_variables_initializer()])
|
||||||
|
res = sess.run([g, g_res, m_new, m_new_res], {
|
||||||
|
x: np.array([[1., 1., 1.]]),
|
||||||
|
m: np.array([[0.1, 0.1, 0.1]])
|
||||||
|
})
|
||||||
|
# Residual connections
|
||||||
|
self.assertAllClose(res[1], res[0] + [1., 1., 1.])
|
||||||
|
# States are left untouched
|
||||||
|
self.assertAllClose(res[2], res[3])
|
||||||
|
|
||||||
|
def testDeviceWrapper(self):
|
||||||
|
with variable_scope.variable_scope(
|
||||||
|
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||||
|
x = array_ops.zeros([1, 3])
|
||||||
|
m = array_ops.zeros([1, 3])
|
||||||
|
cell = core_rnn_cell_impl.DeviceWrapper(
|
||||||
|
core_rnn_cell_impl.GRUCell(3), "/cpu:14159")
|
||||||
|
outputs, _ = cell(x, m)
|
||||||
|
self.assertTrue("cpu:14159" in outputs.device.lower())
|
||||||
|
|
||||||
def testDropoutWrapper(self):
|
def testDropoutWrapper(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with variable_scope.variable_scope(
|
with variable_scope.variable_scope(
|
||||||
|
@ -37,6 +37,8 @@
|
|||||||
@@EmbeddingWrapper
|
@@EmbeddingWrapper
|
||||||
@@InputProjectionWrapper
|
@@InputProjectionWrapper
|
||||||
@@OutputProjectionWrapper
|
@@OutputProjectionWrapper
|
||||||
|
@@DeviceWrapper
|
||||||
|
@@ResidualWrapper
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -528,6 +528,60 @@ class DropoutWrapper(RNNCell):
|
|||||||
return output, new_state
|
return output, new_state
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualWrapper(RNNCell):
|
||||||
|
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
|
||||||
|
|
||||||
|
def __init__(self, cell):
|
||||||
|
"""Constructs a `ResidualWrapper` for `cell`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: An instance of `RNNCell`.
|
||||||
|
"""
|
||||||
|
self._cell = cell
|
||||||
|
|
||||||
|
def __call__(self, inputs, state, scope=None):
|
||||||
|
"""Run the cell and add its inputs to its outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: cell inputs.
|
||||||
|
state: cell state.
|
||||||
|
scope: optional cell scope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of cell outputs and new state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If cell inputs and outputs have different structure (type).
|
||||||
|
ValueError: If cell inputs and outputs have different structure (value).
|
||||||
|
"""
|
||||||
|
output, new_state = self._cell(inputs, state, scope=scope)
|
||||||
|
nest.assert_same_structure(inputs, output)
|
||||||
|
res_output = nest.map_structure(
|
||||||
|
lambda inp, out: inp + out, inputs, output)
|
||||||
|
return (res_output, new_state)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceWrapper(RNNCell):
|
||||||
|
"""Operator that ensures an RNNCell runs on a particular device."""
|
||||||
|
|
||||||
|
def __init__(self, cell, device):
|
||||||
|
"""Construct a `DeviceWrapper` for `cell` with device `device`.
|
||||||
|
|
||||||
|
Ensures the wrapped `cell` is called with `tf.device(device)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: An instance of `RNNCell`.
|
||||||
|
device: A device string or function, for passing to `tf.device`.
|
||||||
|
"""
|
||||||
|
self._cell = cell
|
||||||
|
self._device = device
|
||||||
|
|
||||||
|
def __call__(self, inputs, state, scope=None):
|
||||||
|
"""Run the cell on specified device."""
|
||||||
|
with ops.device(self._device):
|
||||||
|
return self._cell(inputs, state, scope=scope)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingWrapper(RNNCell):
|
class EmbeddingWrapper(RNNCell):
|
||||||
"""Operator adding input embedding to the given cell.
|
"""Operator adding input embedding to the given cell.
|
||||||
|
|
||||||
@ -615,6 +669,10 @@ class MultiRNNCell(RNNCell):
|
|||||||
"""
|
"""
|
||||||
if not cells:
|
if not cells:
|
||||||
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
||||||
|
if not nest.is_sequence(cells):
|
||||||
|
raise TypeError(
|
||||||
|
"cells must be a list or tuple, but saw: %s." % cells)
|
||||||
|
|
||||||
self._cells = cells
|
self._cells = cells
|
||||||
self._state_is_tuple = state_is_tuple
|
self._state_is_tuple = state_is_tuple
|
||||||
if not state_is_tuple:
|
if not state_is_tuple:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user