Fix the output shape of functional_rnn for time-major inputs.

PiperOrigin-RevId: 207780606
A. Unique TensorFlower 2018-08-07 14:31:49 -07:00 committed by TensorFlower Gardener
parent cf233f7281
commit 9b84c91b68
2 changed files with 149 additions and 18 deletions
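For context: both functional_rnn and dynamic_rnn accept batch-major inputs of shape [batch, max_time, depth] (time_major=False) or time-major inputs of shape [max_time, batch, depth] (time_major=True), and callers expect the output to come back in the same layout as the input. The functional_rnn change below transposes the otherwise batch-major output when time_major=True. A minimal NumPy sketch of the two layouts (illustrative only; the sizes and names here are invented for this note and are not part of the commit):

import numpy as np

# Hypothetical sizes, used only to illustrate the layout convention.
batch, max_time, depth = 4, 7, 3
batch_major = np.zeros([batch, max_time, depth])    # time_major=False layout
time_major = np.transpose(batch_major, [1, 0, 2])   # time_major=True layout
assert time_major.shape == (max_time, batch, depth)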


@@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     func, args = self._CELLDEFS[celldef_name]
     return func(*args)
 
-  def _CreateInputs(self):
-    inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
-                               FunctionalRnnTest._TOTAL_TIME,
-                               FunctionalRnnTest._INPUT_SIZE])
+  def _CreateInputs(self, time_major=False):
+    if time_major:
+      inputs = np.random.random([
+          FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
+    else:
+      inputs = np.random.random([
+          FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
     # Always leave one time slot empty, to check max_length behavior.
     sequence_length = np.random.randint(
         0, high=FunctionalRnnTest._TOTAL_TIME - 1,
@@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         dtype=np.int)
     return (inputs, sequence_length)
 
-  def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
-                      tf_sequence_length, initial_state=None,
-                      time_major=None, scope=None):
-    tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
-                                            sequence_length=tf_sequence_length,
-                                            initial_state=initial_state,
-                                            dtype=dtypes.float32,
-                                            time_major=time_major,
-                                            scope=scope)
+  def _CreateSymmetricInputs(self):
+    # total time = batch size
+    inputs = np.zeros(
+        (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
+         FunctionalRnnTest._INPUT_SIZE))
+    for i in range(FunctionalRnnTest._BATCH_SIZE):
+      for j in range(i, FunctionalRnnTest._BATCH_SIZE):
+        inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
+        inputs[j][i] = inputs[i][j]
+
+    # Always leave one time slot empty, to check max_length behavior.
+    sequence_length = np.random.randint(
+        0,
+        high=FunctionalRnnTest._BATCH_SIZE - 1,
+        size=FunctionalRnnTest._BATCH_SIZE,
+        dtype=np.int)
+    return (inputs, sequence_length)
+
+  def _CreateRnnGraph(self,
+                      create_rnn_computation_func,
+                      cell,
+                      tf_inputs,
+                      tf_sequence_length,
+                      is_bidirectional,
+                      initial_state=None,
+                      time_major=None,
+                      scope=None):
+    if is_bidirectional:
+      tf_result = create_rnn_computation_func(
+          cell_fw=cell,
+          cell_bw=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
+    else:
+      tf_result = create_rnn_computation_func(
+          cell=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          initial_state=initial_state,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
     grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
     return {'inference': tf_result, 'grad': grad}
 
@@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         variable_cache[n] = v
 
   def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
-              is_dynamic):
+              is_dynamic, time_major=None, is_bidirectional=False):
     with ops.Graph().as_default() as graph:
       tf_inputs = array_ops.placeholder(
          dtypes.float32, shape=numpy_inputs.shape)
      tf_slen = array_ops.placeholder(dtypes.int32)
      feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
      cell = self._CreateCell(cell_name)
-      fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
-      fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+      if is_dynamic:
+        if is_bidirectional:
+          fn = rnn_lib.bidirectional_dynamic_rnn
+        else:
+          fn = rnn_lib.dynamic_rnn
+      else:
+        if is_bidirectional:
+          fn = functional_rnn.bidirectional_functional_rnn
+        else:
+          fn = functional_rnn.functional_rnn
+
+      fetches = self._CreateRnnGraph(
+          fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
       with self.test_session(graph=graph) as sess:
         sess.run(variables.global_variables_initializer())
         # Note that cell.trainable_variables is not always set.
@@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
     self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
 
+  def testLstmWithTimeMajorInputs(self):
+    """Checks an LSTM against the reference implementation, with time_major."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major=True)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
+    _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstmWithTimeMajorInputs(self):
+    """Checks a bi-directional LSTM with time-major inputs."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(
+        *(args + [False]), time_major=time_major, is_bidirectional=True)
+    _, dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=time_major, is_bidirectional=True)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    # TODO(b/112170761): Uncomment this assertion once the bug is fixed.
+    # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstm(self):
+    """Checks time-major and batch-major rnn produce consistent results."""
+    time_major_inputs, np_slen = self._CreateInputs(True)
+    batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
+    var_cache = {}
+    args = [np_slen, 'lstm', var_cache, False]
+    _, time_major_rnn = self._RunRnn(
+        *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
+    _, batch_major_rnn = self._RunRnn(
+        *([batch_major_inputs] + args), time_major=False, is_bidirectional=True)
+    # Convert the batch-major outputs to be time-major before the comparison.
+    outputs, state = batch_major_rnn['inference']
+    outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
+    batch_major_rnn['inference'] = [outputs, state]
+    self.assertAllClose(time_major_rnn['inference'],
+                        batch_major_rnn['inference'])
+    self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])
+
+  def testBidirectionalLstmWithSymmetricInputs(self):
+    """Checks a bi-directional LSTM with symmetric inputs.
+
+    time-major and batch-major rnn produce the same result with symmetric
+    inputs.
+    """
+    np_inputs, np_slen = self._CreateSymmetricInputs()
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, time_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=True, is_bidirectional=True)
+    _, batch_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=False, is_bidirectional=True)
+    _, time_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=True, is_bidirectional=True)
+    _, batch_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=False, is_bidirectional=True)
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_func_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_func_rnn['grad'])
+    self.assertAllClose(time_major_dyn_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_dyn_rnn['grad'])
+
 
 if __name__ == '__main__':
   test_lib.main()


@@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None,
       inputs=inputs,
       cell_fn=func_cell.cell_step,
       use_tpu=use_tpu)
-  return _PostProcessOutput(extended_acc_state, extended_final_state,
-                            func_cell, inputs_flat[0].shape[0], sequence_length)
+  tf_output, tf_state = _PostProcessOutput(
+      extended_acc_state, extended_final_state, func_cell,
+      inputs_flat[0].shape[0], sequence_length)
+
+  if time_major:
+    tf_output = array_ops.transpose(tf_output, [1, 0, 2])
+  return tf_output, tf_state
 
 
 def bidirectional_functional_rnn(
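With the change above, a time-major call to functional_rnn returns time-major outputs, matching the dynamic_rnn behavior that the new tests compare against. A hedged usage sketch (TF 1.x contrib API; the import path and the LSTMCell constructor are assumptions based on the test file, not shown in this diff):

import tensorflow as tf
from tensorflow.contrib.recurrent.python.ops import functional_rnn

# Time-major inputs: [max_time, batch, depth].
inputs = tf.placeholder(tf.float32, shape=[7, 4, 3])
cell = tf.nn.rnn_cell.LSTMCell(5)
outputs, state = functional_rnn.functional_rnn(
    cell=cell, inputs=inputs, dtype=tf.float32, time_major=True)
# With this fix, outputs is time-major as well: [max_time, batch, num_units] = [7, 4, 5].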