Fix the output shape of functional_rnn for time-major inputs.
PiperOrigin-RevId: 207780606
commit 9b84c91b68 (parent cf233f7281)
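
For context: in the time-major layout the inputs are shaped [time, batch, input_size], while the default batch-major layout is [batch, time, input_size]; the two differ only by a transpose of the first two axes, which is exactly the permutation this fix applies to the output. A minimal NumPy sketch of that convention (illustrative only; the sizes and names below are not taken from the change):

    import numpy as np

    total_time, batch_size, input_size = 7, 3, 4
    time_major = np.random.random([total_time, batch_size, input_size])
    batch_major = np.transpose(time_major, [1, 0, 2])  # [batch, time, input]
    # Applying the same permutation again recovers the time-major layout.
    assert np.array_equal(np.transpose(batch_major, [1, 0, 2]), time_major)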
@@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     func, args = self._CELLDEFS[celldef_name]
     return func(*args)

-  def _CreateInputs(self):
-    inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
-                               FunctionalRnnTest._TOTAL_TIME,
-                               FunctionalRnnTest._INPUT_SIZE])
+  def _CreateInputs(self, time_major=False):
+    if time_major:
+      inputs = np.random.random([
+          FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
+    else:
+      inputs = np.random.random([
+          FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
+          FunctionalRnnTest._INPUT_SIZE
+      ])
     # Always leave one time slot empty, to check max_length behavior.
     sequence_length = np.random.randint(
         0, high=FunctionalRnnTest._TOTAL_TIME - 1,
@@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         dtype=np.int)
     return (inputs, sequence_length)

-  def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
-                      tf_sequence_length, initial_state=None,
-                      time_major=None, scope=None):
-    tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
-                                            sequence_length=tf_sequence_length,
-                                            initial_state=initial_state,
-                                            dtype=dtypes.float32,
-                                            time_major=time_major,
-                                            scope=scope)
+  def _CreateSymmetricInputs(self):
+    # total time = batch size
+    inputs = np.zeros(
+        (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
+         FunctionalRnnTest._INPUT_SIZE))
+    for i in range(FunctionalRnnTest._BATCH_SIZE):
+      for j in range(i, FunctionalRnnTest._BATCH_SIZE):
+        inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
+        inputs[j][i] = inputs[i][j]
+
+    # Always leave one time slot empty, to check max_length behavior.
+    sequence_length = np.random.randint(
+        0,
+        high=FunctionalRnnTest._BATCH_SIZE - 1,
+        size=FunctionalRnnTest._BATCH_SIZE,
+        dtype=np.int)
+    return (inputs, sequence_length)
+
+  def _CreateRnnGraph(self,
+                      create_rnn_computation_func,
+                      cell,
+                      tf_inputs,
+                      tf_sequence_length,
+                      is_bidirectional,
+                      initial_state=None,
+                      time_major=None,
+                      scope=None):
+    if is_bidirectional:
+      tf_result = create_rnn_computation_func(
+          cell_fw=cell,
+          cell_bw=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
+    else:
+      tf_result = create_rnn_computation_func(
+          cell=cell,
+          inputs=tf_inputs,
+          sequence_length=tf_sequence_length,
+          initial_state=initial_state,
+          dtype=dtypes.float32,
+          time_major=time_major,
+          scope=scope)
     grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
     return {'inference': tf_result, 'grad': grad}

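A note on the new _CreateSymmetricInputs helper: it makes the time and batch dimensions equal and fills the array so that inputs[i][j] == inputs[j][i], which means swapping the first two axes leaves the array unchanged. That is what later lets the symmetric-input test feed the very same array as both a time-major and a batch-major input and expect matching results. A small NumPy check of the property (a sketch mirroring the helper's construction rather than calling it):

    import numpy as np

    size, input_size = 3, 4
    inputs = np.zeros((size, size, input_size))
    for i in range(size):
      for j in range(i, size):
        inputs[i][j] = np.random.random([input_size])
        inputs[j][i] = inputs[i][j]
    # Symmetric in the first two axes: time-major and batch-major views agree.
    assert np.allclose(np.transpose(inputs, [1, 0, 2]), inputs)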
@@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
         variable_cache[n] = v

   def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
-              is_dynamic):
+              is_dynamic, time_major=None, is_bidirectional=False):
     with ops.Graph().as_default() as graph:
       tf_inputs = array_ops.placeholder(
           dtypes.float32, shape=numpy_inputs.shape)
       tf_slen = array_ops.placeholder(dtypes.int32)
       feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
       cell = self._CreateCell(cell_name)
-      fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
-      fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+      if is_dynamic:
+        if is_bidirectional:
+          fn = rnn_lib.bidirectional_dynamic_rnn
+        else:
+          fn = rnn_lib.dynamic_rnn
+      else:
+        if is_bidirectional:
+          fn = functional_rnn.bidirectional_functional_rnn
+        else:
+          fn = functional_rnn.functional_rnn
+
+      fetches = self._CreateRnnGraph(
+          fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
       with self.test_session(graph=graph) as sess:
         sess.run(variables.global_variables_initializer())
         # Note that cell.trainable_variables it not always set.
@@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
     self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
     self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

+  def testLstmWithTimeMajorInputs(self):
+    """Checks an LSTM against the reference implementation, with time_major."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major=True)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
+    _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstmWithTimeMajorInputs(self):
+    """Checks a bi-directional LSTM with time-major inputs."""
+    time_major = True
+    np_inputs, np_slen = self._CreateInputs(time_major)
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, func_rnn = self._RunRnn(
+        *(args + [False]), time_major=time_major, is_bidirectional=True)
+    _, dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=time_major, is_bidirectional=True)
+    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+    # TODO(b/112170761): comment out this line after the bug is fixed.
+    # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+  def testBidirectionalLstm(self):
+    """Checks time-major and batch-major rnn produce consistent results."""
+    time_major_inputs, np_slen = self._CreateInputs(True)
+    batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
+    var_cache = {}
+    args = [np_slen, 'lstm', var_cache, False]
+    _, time_major_rnn = self._RunRnn(
+        *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
+    _, batch_major_rnn = self._RunRnn(
+        *([batch_major_inputs] + args), time_major=False, is_bidirectional=True)
+    # Convert the batch-major outputs to be time-major before the comparasion.
+    outputs, state = batch_major_rnn['inference']
+    outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
+    batch_major_rnn['inference'] = [outputs, state]
+    self.assertAllClose(time_major_rnn['inference'],
+                        batch_major_rnn['inference'])
+    self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])
+
+  def testBidirectionalLstmWithSymmetricInputs(self):
+    """Checks a bi-directional LSTM with symmetric inputs.
+
+    time-major and batch-major rnn produce the same result with symmetric
+    inputs.
+    """
+    np_inputs, np_slen = self._CreateSymmetricInputs()
+    var_cache = {}
+    args = [np_inputs, np_slen, 'lstm', var_cache]
+    _, time_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=True, is_bidirectional=True)
+    _, batch_major_func_rnn = self._RunRnn(
+        *(args + [False]), time_major=False, is_bidirectional=True)
+    _, time_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=True, is_bidirectional=True)
+    _, batch_major_dyn_rnn = self._RunRnn(
+        *(args + [True]), time_major=False, is_bidirectional=True)
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_func_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_func_rnn['grad'])
+    self.assertAllClose(time_major_dyn_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
+    self.assertAllClose(time_major_func_rnn['inference'],
+                        batch_major_dyn_rnn['inference'])
+    self.assertAllClose(time_major_func_rnn['grad'],
+                        batch_major_dyn_rnn['grad'])
+

 if __name__ == '__main__':
   test_lib.main()
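For orientation, the bidirectional time-major path these tests drive reduces to a call like the one below. This is a sketch only: the import path, cell class, and sizes are my assumptions for illustration, while the keyword arguments mirror the bidirectional branch of _CreateRnnGraph.

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import functional_rnn  # assumed path

    np_inputs = np.random.random([7, 3, 4]).astype(np.float32)  # [time, batch, input]
    np_slen = np.array([6, 5, 6], dtype=np.int32)

    cell = tf.nn.rnn_cell.LSTMCell(10)  # stand-in for what _CreateCell('lstm') builds
    outputs, state = functional_rnn.bidirectional_functional_rnn(
        cell_fw=cell,
        cell_bw=cell,
        inputs=tf.constant(np_inputs),
        sequence_length=tf.constant(np_slen),
        dtype=tf.float32,
        time_major=True)
    # `outputs` holds one time-major array per direction, which is why
    # testBidirectionalLstm iterates over the entries and transposes each one.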
@@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None,
       inputs=inputs,
       cell_fn=func_cell.cell_step,
       use_tpu=use_tpu)
-  return _PostProcessOutput(extended_acc_state, extended_final_state,
-                            func_cell, inputs_flat[0].shape[0], sequence_length)
+  tf_output, tf_state = _PostProcessOutput(
+      extended_acc_state, extended_final_state, func_cell,
+      inputs_flat[0].shape[0], sequence_length)
+
+  if time_major:
+    tf_output = array_ops.transpose(tf_output, [1, 0, 2])
+  return tf_output, tf_state


 def bidirectional_functional_rnn(
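The effect of the hunk above: _PostProcessOutput yields a batch-major output, so when the caller asked for time-major data the accumulated output now gets transposed back before being returned rather than handed out in the wrong layout. A minimal sketch of that post-processing step (the helper name below is hypothetical; only the conditional transpose mirrors the change):

    import tensorflow as tf

    def _finalize_rnn_output(tf_output, tf_state, time_major):
      """Hypothetical mirror of the fix: tf_output arrives batch-major."""
      if time_major:
        # [batch, time, units] -> [time, batch, units], matching the input layout.
        tf_output = tf.transpose(tf_output, [1, 0, 2])
      return tf_output, tf_state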