Fix the output shape of functional_rnn for time-major inputs.
PiperOrigin-RevId: 207780606
commit 9b84c91b68 (parent cf233f7281)
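The change below touches both the functional_rnn tests and the functional_rnn op itself (last hunk): previously, even with time_major=True, the accumulated outputs came back batch-major; the fix transposes them back to [time, batch, units]. As a minimal usage sketch of the corrected behavior (not part of the commit; the TF 1.x contrib import path and the LSTM cell below are assumptions, and the shapes are illustrative):

# Hypothetical usage sketch, TF 1.x graph mode; the import path below is assumed.
import tensorflow as tf
from tensorflow.contrib.recurrent.python.ops import functional_rnn

time_steps, batch, depth, units = 7, 3, 4, 5

# Time-major inputs: [time, batch, depth].
inputs = tf.placeholder(tf.float32, shape=[time_steps, batch, depth])
lengths = tf.placeholder(tf.int32, shape=[batch])
cell = tf.nn.rnn_cell.LSTMCell(units)

outputs, state = functional_rnn.functional_rnn(
    cell=cell,
    inputs=inputs,
    sequence_length=lengths,
    dtype=tf.float32,
    time_major=True)

print(outputs.shape)  # expected layout: [time, batch, units] == [7, 3, 5]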
@@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     func, args = self._CELLDEFS[celldef_name]
     return func(*args)
 
-  def _CreateInputs(self):
-    inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
-                               FunctionalRnnTest._TOTAL_TIME,
-                               FunctionalRnnTest._INPUT_SIZE])
+  def _CreateInputs(self, time_major=False):
+    if time_major:
+      inputs = np.random.random([
+          FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
+    else:
+      inputs = np.random.random([
+          FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
     # Always leave one time slot empty, to check max_length behavior.
     sequence_length = np.random.randint(
         0, high=FunctionalRnnTest._TOTAL_TIME - 1,
@@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         dtype=np.int)
     return (inputs, sequence_length)
 
-  def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
-                      tf_sequence_length, initial_state=None,
-                      time_major=None, scope=None):
-    tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
-                                            sequence_length=tf_sequence_length,
-                                            initial_state=initial_state,
-                                            dtype=dtypes.float32,
-                                            time_major=time_major,
-                                            scope=scope)
+  def _CreateSymmetricInputs(self):
+    # total time = batch size
+    inputs = np.zeros(
+        (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
+         FunctionalRnnTest._INPUT_SIZE))
+    for i in range(FunctionalRnnTest._BATCH_SIZE):
+      for j in range(i, FunctionalRnnTest._BATCH_SIZE):
+        inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
+        inputs[j][i] = inputs[i][j]
+
+    # Always leave one time slot empty, to check max_length behavior.
+    sequence_length = np.random.randint(
+        0,
+        high=FunctionalRnnTest._BATCH_SIZE - 1,
+        size=FunctionalRnnTest._BATCH_SIZE,
+        dtype=np.int)
+    return (inputs, sequence_length)
+
+  def _CreateRnnGraph(self,
+                      create_rnn_computation_func,
+                      cell,
+                      tf_inputs,
+                      tf_sequence_length,
+                      is_bidirectional,
+                      initial_state=None,
+                      time_major=None,
+                      scope=None):
+    if is_bidirectional:
+      tf_result = create_rnn_computation_func(
+          cell_fw=cell,
+          cell_bw=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
+    else:
+      tf_result = create_rnn_computation_func(
+          cell=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          initial_state=initial_state,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
     grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
     return {'inference': tf_result, 'grad': grad}
 
@@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         variable_cache[n] = v
 
   def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
-              is_dynamic):
+              is_dynamic, time_major=None, is_bidirectional=False):
     with ops.Graph().as_default() as graph:
       tf_inputs = array_ops.placeholder(
           dtypes.float32, shape=numpy_inputs.shape)
       tf_slen = array_ops.placeholder(dtypes.int32)
       feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
       cell = self._CreateCell(cell_name)
-      fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
-      fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+      if is_dynamic:
+        if is_bidirectional:
+          fn = rnn_lib.bidirectional_dynamic_rnn
+        else:
+          fn = rnn_lib.dynamic_rnn
+      else:
+        if is_bidirectional:
+          fn = functional_rnn.bidirectional_functional_rnn
+        else:
+          fn = functional_rnn.functional_rnn
+
+      fetches = self._CreateRnnGraph(
+          fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
     with self.test_session(graph=graph) as sess:
       sess.run(variables.global_variables_initializer())
       # Note that cell.trainable_variables is not always set.
@@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
     self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
 
+  def testLstmWithTimeMajorInputs(self):
+    """Checks an LSTM against the reference implementation, with time_major."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major=True)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
+    _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstmWithTimeMajorInputs(self):
+    """Checks a bi-directional LSTM with time-major inputs."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(
+        *(args + [False]), time_major=time_major, is_bidirectional=True)
+    _, dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=time_major, is_bidirectional=True)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    # TODO(b/112170761): re-enable this check once the bug is fixed.
+    # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstm(self):
+    """Checks that time-major and batch-major RNNs produce consistent results."""
+    time_major_inputs, np_slen = self._CreateInputs(True)
+    batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
+    var_cache = {}
+    args = [np_slen, 'lstm', var_cache, False]
+    _, time_major_rnn = self._RunRnn(
+        *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
+    _, batch_major_rnn = self._RunRnn(
+        *([batch_major_inputs] + args), time_major=False, is_bidirectional=True)
+    # Convert the batch-major outputs to be time-major before the comparison.
+    outputs, state = batch_major_rnn['inference']
+    outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
+    batch_major_rnn['inference'] = [outputs, state]
+    self.assertAllClose(time_major_rnn['inference'],
+                        batch_major_rnn['inference'])
+    self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])
+
+  def testBidirectionalLstmWithSymmetricInputs(self):
+    """Checks a bi-directional LSTM with symmetric inputs.
+
+    Time-major and batch-major RNNs produce the same result with symmetric
+    inputs.
+    """
+    np_inputs, np_slen = self._CreateSymmetricInputs()
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, time_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=True, is_bidirectional=True)
+    _, batch_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=False, is_bidirectional=True)
+    _, time_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=True, is_bidirectional=True)
+    _, batch_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=False, is_bidirectional=True)
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_func_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_func_rnn['grad'])
+    self.assertAllClose(time_major_dyn_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_dyn_rnn['grad'])
 
 
 if __name__ == '__main__':
   test_lib.main()
@@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None,
       inputs=inputs,
       cell_fn=func_cell.cell_step,
       use_tpu=use_tpu)
-  return _PostProcessOutput(extended_acc_state, extended_final_state,
-                            func_cell, inputs_flat[0].shape[0], sequence_length)
+  tf_output, tf_state = _PostProcessOutput(
+      extended_acc_state, extended_final_state, func_cell,
+      inputs_flat[0].shape[0], sequence_length)
+
+  if time_major:
+    tf_output = array_ops.transpose(tf_output, [1, 0, 2])
+  return tf_output, tf_state
 
 
 def bidirectional_functional_rnn(
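For reference, a small NumPy sketch of the same [1, 0, 2] transpose used in the hunk above (not part of the commit): it swaps the batch and time axes, turning a batch-major array into a time-major one.

import numpy as np

batch, time_steps, units = 3, 7, 5
batch_major = np.zeros([batch, time_steps, units])  # [batch, time, units]
time_major = np.transpose(batch_major, [1, 0, 2])   # [time, batch, units]
print(batch_major.shape, time_major.shape)          # (3, 7, 5) (7, 3, 5)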