When using static_state_saving_rnn(..) in the following manner
_, state = tf.nn.static_state_saving_rnn(..) the runtime will be blocked after some time, because the save_state method of the state_saver object won't be executed as a part of the graph (that part depends only on output node in the current implementation). Now it should depend on state as well, so the above implementation won't be blocked. PiperOrigin-RevId: 196024050
This commit is contained in:
parent
705550357f
commit
ec0ef29835
@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
@ -142,6 +143,47 @@ class TestStateSaver(object):
|
||||
self.saved_state[name] = state
|
||||
return array_ops.identity(state)
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._state_size
|
||||
|
||||
|
||||
class TestStateSaverWithCounters(TestStateSaver):
|
||||
"""Class wrapper around TestStateSaver.
|
||||
|
||||
A dummy class used for testing of static_state_saving_rnn. It helps test if
|
||||
save_state and state functions got called same number of time when we
|
||||
evaluate output of rnn cell and state or either of them separately. It
|
||||
inherits from the TestStateSaver and adds the counters for calls of functions.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size, state_size):
|
||||
super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
|
||||
self._num_state_calls = variables_lib.Variable(0)
|
||||
self._num_save_state_calls = variables_lib.Variable(0)
|
||||
|
||||
def state(self, name):
|
||||
with ops_lib.control_dependencies(
|
||||
[state_ops.assign_add(self._num_state_calls, 1)]):
|
||||
return super(TestStateSaverWithCounters, self).state(name)
|
||||
|
||||
def save_state(self, name, state):
|
||||
with ops_lib.control_dependencies([state_ops.assign_add(
|
||||
self._num_save_state_calls, 1)]):
|
||||
return super(TestStateSaverWithCounters, self).save_state(name, state)
|
||||
|
||||
@property
|
||||
def num_state_calls(self):
|
||||
return self._num_state_calls
|
||||
|
||||
@property
|
||||
def num_save_state_calls(self):
|
||||
return self._num_save_state_calls
|
||||
|
||||
|
||||
class RNNTest(test.TestCase):
|
||||
|
||||
@ -1792,13 +1834,40 @@ class StateSaverRNNTest(test.TestCase):
|
||||
self._seed = 23489
|
||||
np.random.seed(self._seed)
|
||||
|
||||
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
|
||||
def _factory(self, scope, state_saver):
|
||||
num_units = state_saver.state_size // 2
|
||||
batch_size = state_saver.batch_size
|
||||
input_size = 5
|
||||
max_length = 8
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
inputs = max_length * [
|
||||
array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
out, state = rnn.static_state_saving_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
state_saver=state_saver,
|
||||
state_name="save_lstm",
|
||||
scope=scope)
|
||||
return out, state, state_saver
|
||||
|
||||
def _testScope(self, prefix="prefix", use_outer_scope=True):
|
||||
num_units = 3
|
||||
batch_size = 2
|
||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||
|
||||
with self.test_session(use_gpu=True, graph=ops_lib.Graph()):
|
||||
if use_outer_scope:
|
||||
with variable_scope.variable_scope(prefix) as scope:
|
||||
factory(scope)
|
||||
self._factory(scope=scope, state_saver=state_saver)
|
||||
else:
|
||||
factory(prefix)
|
||||
self._factory(scope=prefix, state_saver=state_saver)
|
||||
variables_lib.global_variables_initializer()
|
||||
|
||||
# check that all the variables names starts
|
||||
@ -1813,34 +1882,46 @@ class StateSaverRNNTest(test.TestCase):
|
||||
self.assertEqual(len(scope_vars), len(all_vars))
|
||||
|
||||
def testStateSaverRNNScope(self):
|
||||
self._testScope(use_outer_scope=True)
|
||||
self._testScope(use_outer_scope=False)
|
||||
self._testScope(prefix=None, use_outer_scope=False)
|
||||
|
||||
def testStateSaverCallsSaveState(self):
|
||||
"""Test that number of calls to state and save_state is equal.
|
||||
|
||||
Test if the order of actual evaluating or skipping evaluation of out,
|
||||
state tensors, which are the output tensors from static_state_saving_rnn,
|
||||
have influence on number of calls to save_state and state methods of
|
||||
state_saver object (the number of calls should be same.)
|
||||
"""
|
||||
|
||||
num_units = 3
|
||||
input_size = 5
|
||||
batch_size = 2
|
||||
max_length = 8
|
||||
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
|
||||
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
|
||||
|
||||
def factory(scope):
|
||||
initializer = init_ops.random_uniform_initializer(
|
||||
-0.01, 0.01, seed=self._seed)
|
||||
state_saver = TestStateSaver(batch_size, 2 * num_units)
|
||||
cell = rnn_cell.LSTMCell(
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
initializer=initializer,
|
||||
state_is_tuple=False)
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
return rnn.static_state_saving_rnn(
|
||||
cell,
|
||||
inputs,
|
||||
state_saver=state_saver,
|
||||
state_name="save_lstm",
|
||||
scope=scope)
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables_lib.global_variables_initializer())
|
||||
sess.run(variables_lib.local_variables_initializer())
|
||||
|
||||
self._testScope(factory, use_outer_scope=True)
|
||||
self._testScope(factory, use_outer_scope=False)
|
||||
self._testScope(factory, prefix=None, use_outer_scope=False)
|
||||
_, _, num_state_calls, num_save_state_calls = sess.run([
|
||||
out,
|
||||
state,
|
||||
state_saver.num_state_calls,
|
||||
state_saver.num_save_state_calls])
|
||||
self.assertEqual(num_state_calls, num_save_state_calls)
|
||||
|
||||
_, num_state_calls, num_save_state_calls = sess.run([
|
||||
out,
|
||||
state_saver.num_state_calls,
|
||||
state_saver.num_save_state_calls])
|
||||
self.assertEqual(num_state_calls, num_save_state_calls)
|
||||
|
||||
_, num_state_calls, num_save_state_calls = sess.run([
|
||||
state,
|
||||
state_saver.num_state_calls,
|
||||
state_saver.num_save_state_calls])
|
||||
self.assertEqual(num_state_calls, num_save_state_calls)
|
||||
|
||||
class GRUTest(test.TestCase):
|
||||
|
||||
|
@ -1401,6 +1401,13 @@ def static_state_saving_rnn(cell,
|
||||
outputs[-1] = nest.pack_sequence_as(
|
||||
structure=last_output, flat_sequence=flat_last_output)
|
||||
|
||||
if state_is_tuple:
|
||||
state = nest.pack_sequence_as(
|
||||
structure=state,
|
||||
flat_sequence=[array_ops.identity(s) for s in flat_state])
|
||||
else:
|
||||
state = array_ops.identity(state)
|
||||
|
||||
return (outputs, state)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user