Add unidirectional_sequence_lstm to generated_examples tests
PiperOrigin-RevId: 236984026
This commit is contained in:
parent
dc3900a969
commit
afab5b322f
@ -325,6 +325,7 @@ def generated_test_models():
|
||||
"topk",
|
||||
"transpose",
|
||||
"transpose_conv",
|
||||
"unidirectional_sequence_lstm",
|
||||
"unique",
|
||||
"unpack",
|
||||
"unroll_batch_matmul",
|
||||
@ -340,6 +341,7 @@ def generated_test_models_failing(conversion_mode):
|
||||
return [
|
||||
"lstm", # TODO(b/117510976): Restore when lstm flex conversion works.
|
||||
"unroll_batch_matmul", # TODO(b/123030774): Fails in 1.13 tests.
|
||||
"unidirectional_sequence_lstm",
|
||||
]
|
||||
|
||||
return []
|
||||
|
@ -54,6 +54,7 @@ from google.protobuf import text_format
|
||||
# TODO(aselle): switch to TensorFlow's resource_loader
|
||||
from tensorflow.lite.testing import generate_examples_report as report_lib
|
||||
from tensorflow.lite.testing import string_util_wrapper
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import graph_util as tf_graph_util
|
||||
from tensorflow.python.ops import rnn
|
||||
|
||||
@ -504,6 +505,10 @@ def make_zip_of_tests(zip_path,
|
||||
extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
|
||||
"split_tflite_lstm_inputs"]
|
||||
|
||||
# Convert ophint ops if presented.
|
||||
graph_def = tf.lite.experimental.convert_op_hints_to_stubs(
|
||||
graph_def=graph_def)
|
||||
graph_def = tf.graph_util.remove_training_nodes(graph_def)
|
||||
tflite_model_binary, toco_log = toco_convert(
|
||||
graph_def.SerializeToString(), input_tensors, output_tensors,
|
||||
extra_toco_options)
|
||||
@ -4357,6 +4362,83 @@ def make_reverse_sequence_tests(zip_path):
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def make_unidirectional_sequence_lstm_tests(zip_path):
|
||||
"""Make a set of tests to do unidirectional_sequence_lstm."""
|
||||
|
||||
test_parameters = [{
|
||||
"batch_size": [2, 4, 6],
|
||||
"seq_length": [1, 3],
|
||||
"units": [4, 5],
|
||||
"use_peepholes": [False, True],
|
||||
"is_dynamic_rnn": [False, True]
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
input_values = []
|
||||
if parameters["is_dynamic_rnn"]:
|
||||
shape = [
|
||||
parameters["seq_length"], parameters["batch_size"],
|
||||
parameters["units"]
|
||||
]
|
||||
input_value = tf.placeholder(dtype=tf.float32, name="input", shape=shape)
|
||||
input_values.append(input_value)
|
||||
lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
parameters["units"],
|
||||
use_peepholes=parameters["use_peepholes"])
|
||||
outs, _ = tf.lite.experimental.nn.dynamic_rnn(
|
||||
lstm_cell, input_value, dtype=tf.float32, time_major=True)
|
||||
outs = tf.unstack(outs, axis=1)
|
||||
else:
|
||||
shape = [parameters["batch_size"], parameters["units"]]
|
||||
for i in range(parameters["seq_length"]):
|
||||
input_value = tf.placeholder(
|
||||
dtype=tf.float32, name=("input_%d" % i), shape=shape)
|
||||
input_values.append(input_value)
|
||||
lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell(
|
||||
parameters["units"], use_peepholes=parameters["use_peepholes"])
|
||||
outs, _ = tf.nn.static_rnn(lstm_cell, input_values, dtype=tf.float32)
|
||||
|
||||
real_output = tf.zeros([1], dtype=tf.float32) + outs[-1]
|
||||
return input_values, [real_output]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_values = []
|
||||
if parameters["is_dynamic_rnn"]:
|
||||
shape = [
|
||||
parameters["seq_length"], parameters["batch_size"],
|
||||
parameters["units"]
|
||||
]
|
||||
input_value = create_tensor_data(tf.float32, shape)
|
||||
input_values.append(input_value)
|
||||
else:
|
||||
shape = [parameters["batch_size"], parameters["units"]]
|
||||
for i in range(parameters["seq_length"]):
|
||||
input_value = create_tensor_data(tf.float32, shape)
|
||||
input_values.append(input_value)
|
||||
init = tf.global_variables_initializer()
|
||||
sess.run(init)
|
||||
# Tflite fused kernel takes input as [time, batch, input].
|
||||
# For static unidirectional sequence lstm, the input is an array sized of
|
||||
# time, and pack the array together, however, for time = 1, the input is
|
||||
# not packed.
|
||||
tflite_input_values = input_values
|
||||
if not parameters["is_dynamic_rnn"] and parameters["seq_length"] == 1:
|
||||
tflite_input_values = [
|
||||
input_values[0].reshape((1, parameters["batch_size"],
|
||||
parameters["units"]))
|
||||
]
|
||||
return tflite_input_values, sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, input_values)))
|
||||
|
||||
make_zip_of_tests(
|
||||
zip_path,
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
use_frozen_graph=True)
|
||||
|
||||
|
||||
# Toco binary path provided by the generate rule.
|
||||
bin_path = None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user