When using static_state_saving_rnn(..) in the following manner

_, state = tf.nn.static_state_saving_rnn(..)

the runtime will be blocked after some time, because the save_state method of the state_saver object won't be executed as a part of the graph (that part depends only on output node in the current implementation).
Now it should depend on state as well, so the above implementation won't be blocked.

PiperOrigin-RevId: 196024050
This commit is contained in:
A. Unique TensorFlower 2018-05-09 13:55:20 -07:00 committed by TensorFlower Gardener
parent 705550357f
commit ec0ef29835
2 changed files with 114 additions and 26 deletions

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@ -142,6 +143,47 @@ class TestStateSaver(object):
self.saved_state[name] = state
return array_ops.identity(state)
@property
def batch_size(self):
return self._batch_size
@property
def state_size(self):
return self._state_size
class TestStateSaverWithCounters(TestStateSaver):
"""Class wrapper around TestStateSaver.
A dummy class used for testing of static_state_saving_rnn. It helps test if
save_state and state functions got called same number of time when we
evaluate output of rnn cell and state or either of them separately. It
inherits from the TestStateSaver and adds the counters for calls of functions.
"""
def __init__(self, batch_size, state_size):
super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
self._num_state_calls = variables_lib.Variable(0)
self._num_save_state_calls = variables_lib.Variable(0)
def state(self, name):
with ops_lib.control_dependencies(
[state_ops.assign_add(self._num_state_calls, 1)]):
return super(TestStateSaverWithCounters, self).state(name)
def save_state(self, name, state):
with ops_lib.control_dependencies([state_ops.assign_add(
self._num_save_state_calls, 1)]):
return super(TestStateSaverWithCounters, self).save_state(name, state)
@property
def num_state_calls(self):
return self._num_state_calls
@property
def num_save_state_calls(self):
return self._num_save_state_calls
class RNNTest(test.TestCase):
@ -1792,13 +1834,40 @@ class StateSaverRNNTest(test.TestCase):
self._seed = 23489
np.random.seed(self._seed)
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
def _factory(self, scope, state_saver):
num_units = state_saver.state_size // 2
batch_size = state_saver.batch_size
input_size = 5
max_length = 8
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=False,
initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size))
]
out, state = rnn.static_state_saving_rnn(
cell,
inputs,
state_saver=state_saver,
state_name="save_lstm",
scope=scope)
return out, state, state_saver
def _testScope(self, prefix="prefix", use_outer_scope=True):
num_units = 3
batch_size = 2
state_saver = TestStateSaver(batch_size, 2 * num_units)
with self.test_session(use_gpu=True, graph=ops_lib.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
factory(scope)
self._factory(scope=scope, state_saver=state_saver)
else:
factory(prefix)
self._factory(scope=prefix, state_saver=state_saver)
variables_lib.global_variables_initializer()
# check that all the variables names starts
@ -1813,34 +1882,46 @@ class StateSaverRNNTest(test.TestCase):
self.assertEqual(len(scope_vars), len(all_vars))
def testStateSaverRNNScope(self):
self._testScope(use_outer_scope=True)
self._testScope(use_outer_scope=False)
self._testScope(prefix=None, use_outer_scope=False)
def testStateSaverCallsSaveState(self):
"""Test that number of calls to state and save_state is equal.
Test if the order of actual evaluating or skipping evaluation of out,
state tensors, which are the output tensors from static_state_saving_rnn,
have influence on number of calls to save_state and state methods of
state_saver object (the number of calls should be same.)
"""
num_units = 3
input_size = 5
batch_size = 2
max_length = 8
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
def factory(scope):
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, 2 * num_units)
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=False,
initializer=initializer,
state_is_tuple=False)
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
return rnn.static_state_saving_rnn(
cell,
inputs,
state_saver=state_saver,
state_name="save_lstm",
scope=scope)
with self.test_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
self._testScope(factory, use_outer_scope=True)
self._testScope(factory, use_outer_scope=False)
self._testScope(factory, prefix=None, use_outer_scope=False)
_, _, num_state_calls, num_save_state_calls = sess.run([
out,
state,
state_saver.num_state_calls,
state_saver.num_save_state_calls])
self.assertEqual(num_state_calls, num_save_state_calls)
_, num_state_calls, num_save_state_calls = sess.run([
out,
state_saver.num_state_calls,
state_saver.num_save_state_calls])
self.assertEqual(num_state_calls, num_save_state_calls)
_, num_state_calls, num_save_state_calls = sess.run([
state,
state_saver.num_state_calls,
state_saver.num_save_state_calls])
self.assertEqual(num_state_calls, num_save_state_calls)
class GRUTest(test.TestCase):

View File

@ -1401,6 +1401,13 @@ def static_state_saving_rnn(cell,
outputs[-1] = nest.pack_sequence_as(
structure=last_output, flat_sequence=flat_last_output)
if state_is_tuple:
state = nest.pack_sequence_as(
structure=state,
flat_sequence=[array_ops.identity(s) for s in flat_state])
else:
state = array_ops.identity(state)
return (outputs, state)