remove bidi rnn seq_len test, we didn't support seq_len anyway in tflite kernel anyway
PiperOrigin-RevId: 301737256 Change-Id: I6bf748c78abe99075aeb30ec20a1ca265ab6717c
This commit is contained in:
parent
6eca562622
commit
7fb2f93d40
@ -111,7 +111,6 @@ py_test(
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"notap", # b/141373014
|
||||
],
|
||||
deps = [
|
||||
":rnn",
|
||||
|
@ -297,33 +297,6 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
def testStaticRnnMultiRnnCellWithSequenceLength(self):
|
||||
sess = tf.compat.v1.Session()
|
||||
|
||||
x, prediction, output_class = self.buildModel(
|
||||
self.buildRnnLayer(),
|
||||
self.buildRnnLayer(),
|
||||
False,
|
||||
is_inference=False,
|
||||
use_sequence_length=True)
|
||||
self.trainModel(x, prediction, output_class, sess)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
x, prediction, output_class, new_sess = self.saveAndRestoreModel(
|
||||
self.buildRnnLayer(),
|
||||
self.buildRnnLayer(),
|
||||
sess,
|
||||
saver,
|
||||
False,
|
||||
use_sequence_length=True)
|
||||
|
||||
test_inputs, expected_output = self.getInferenceResult(
|
||||
x, output_class, new_sess)
|
||||
|
||||
# Test Toco-converted model.
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testDynamicRnnMultiRnnCell(self):
|
||||
sess = tf.compat.v1.Session()
|
||||
@ -347,34 +320,6 @@ class BidirectionalSequenceRnnTest(test_util.TensorFlowTestCase):
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testDynamicRnnMultiRnnCellWithSequenceLength(self):
|
||||
sess = tf.compat.v1.Session()
|
||||
|
||||
x, prediction, output_class = self.buildModel(
|
||||
self.buildRnnLayer(),
|
||||
self.buildRnnLayer(),
|
||||
True,
|
||||
is_inference=False,
|
||||
use_sequence_length=True)
|
||||
self.trainModel(x, prediction, output_class, sess)
|
||||
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
x, prediction, output_class, new_sess = self.saveAndRestoreModel(
|
||||
self.buildRnnLayer(),
|
||||
self.buildRnnLayer(),
|
||||
sess,
|
||||
saver,
|
||||
is_dynamic_rnn=True,
|
||||
use_sequence_length=True)
|
||||
|
||||
test_inputs, expected_output = self.getInferenceResult(
|
||||
x, output_class, new_sess)
|
||||
|
||||
# Test Toco-converted model.
|
||||
result = self.tfliteInvoke(new_sess, test_inputs, x, output_class, False)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
|
Loading…
x
Reference in New Issue
Block a user