Add unidirectional_sequence_lstm to generated_examples tests

PiperOrigin-RevId: 236984026
A. Unique TensorFlower 2019-03-05 22:33:29 -08:00 committed by TensorFlower Gardener
parent dc3900a969
commit afab5b322f
2 changed files with 84 additions and 0 deletions

tensorflow/lite/build_def.bzl

@@ -325,6 +325,7 @@ def generated_test_models():
"topk",
"transpose",
"transpose_conv",
"unidirectional_sequence_lstm",
"unique",
"unpack",
"unroll_batch_matmul",
@@ -340,6 +341,7 @@ def generated_test_models_failing(conversion_mode):
        return [
            "lstm",  # TODO(b/117510976): Restore when lstm flex conversion works.
            "unroll_batch_matmul",  # TODO(b/123030774): Fails in 1.13 tests.
            "unidirectional_sequence_lstm",
        ]
    return []
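
These two lists drive the zip-test generation: every name in generated_test_models() gets a generated test target, while names returned by generated_test_models_failing(conversion_mode) are expected to fail under that mode (the branch above appears to be flex conversion, judging by the lstm TODO). A minimal sketch of how the lists relate; the helper name passing_test_models is hypothetical and not part of the build rules:

def passing_test_models(conversion_mode):
    # Hypothetical helper: models that are generated but expected to
    # fail under this conversion mode are filtered out.
    failing = generated_test_models_failing(conversion_mode)
    return [m for m in generated_test_models() if m not in failing]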

tensorflow/lite/testing/generate_examples.py

@@ -54,6 +54,7 @@ from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
from tensorflow.lite.testing import generate_examples_report as report_lib
from tensorflow.lite.testing import string_util_wrapper
from tensorflow.python.framework import test_util
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.ops import rnn
@@ -504,6 +505,10 @@ def make_zip_of_tests(zip_path,
          extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
              "split_tflite_lstm_inputs"]
        # Convert ophint ops if present.
        graph_def = tf.lite.experimental.convert_op_hints_to_stubs(
            graph_def=graph_def)
        graph_def = tf.graph_util.remove_training_nodes(graph_def)
        tflite_model_binary, toco_log = toco_convert(
            graph_def.SerializeToString(), input_tensors, output_tensors,
            extra_toco_options)
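
The two GraphDef passes added above are what make the new test convertible: convert_op_hints_to_stubs fuses the OpHint-annotated LSTM subgraph built by TFLiteLSTMCell into a single stub op that the converter can lower to the fused unidirectional sequence LSTM kernel, and remove_training_nodes strips training-only nodes before conversion. A rough sketch of the same pipeline with the public TF 1.x converter API (illustrative only; the harness uses its own toco_convert wrapper so it can pass extra_toco_options):

# Sketch, assuming graph_def, input_tensors and output_tensors already exist.
import tensorflow as tf

graph_def = tf.lite.experimental.convert_op_hints_to_stubs(graph_def=graph_def)
graph_def = tf.graph_util.remove_training_nodes(graph_def)
converter = tf.lite.TFLiteConverter(graph_def, input_tensors, output_tensors)
tflite_model = converter.convert()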
@@ -4357,6 +4362,83 @@ def make_reverse_sequence_tests(zip_path):
  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


@test_util.enable_control_flow_v2
def make_unidirectional_sequence_lstm_tests(zip_path):
  """Make a set of tests to do unidirectional_sequence_lstm."""
  test_parameters = [{
      "batch_size": [2, 4, 6],
      "seq_length": [1, 3],
      "units": [4, 5],
      "use_peepholes": [False, True],
      "is_dynamic_rnn": [False, True]
  }]
  def build_graph(parameters):
    input_values = []
    if parameters["is_dynamic_rnn"]:
      shape = [
          parameters["seq_length"], parameters["batch_size"],
          parameters["units"]
      ]
      input_value = tf.placeholder(dtype=tf.float32, name="input", shape=shape)
      input_values.append(input_value)
      lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell(
          parameters["units"],
          use_peepholes=parameters["use_peepholes"])
      outs, _ = tf.lite.experimental.nn.dynamic_rnn(
          lstm_cell, input_value, dtype=tf.float32, time_major=True)
      outs = tf.unstack(outs, axis=1)
    else:
      shape = [parameters["batch_size"], parameters["units"]]
      for i in range(parameters["seq_length"]):
        input_value = tf.placeholder(
            dtype=tf.float32, name=("input_%d" % i), shape=shape)
        input_values.append(input_value)
      lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell(
          parameters["units"], use_peepholes=parameters["use_peepholes"])
      outs, _ = tf.nn.static_rnn(lstm_cell, input_values, dtype=tf.float32)
    real_output = tf.zeros([1], dtype=tf.float32) + outs[-1]
    return input_values, [real_output]
  def build_inputs(parameters, sess, inputs, outputs):
    input_values = []
    if parameters["is_dynamic_rnn"]:
      shape = [
          parameters["seq_length"], parameters["batch_size"],
          parameters["units"]
      ]
      input_value = create_tensor_data(tf.float32, shape)
      input_values.append(input_value)
    else:
      shape = [parameters["batch_size"], parameters["units"]]
      for i in range(parameters["seq_length"]):
        input_value = create_tensor_data(tf.float32, shape)
        input_values.append(input_value)
    init = tf.global_variables_initializer()
    sess.run(init)
    # The TFLite fused kernel takes its input as [time, batch, input].
    # For the static unidirectional sequence LSTM, the input is a list of
    # seq_length tensors that get packed together; for seq_length = 1,
    # however, the input is not packed, so reshape it here.
    tflite_input_values = input_values
    if not parameters["is_dynamic_rnn"] and parameters["seq_length"] == 1:
      tflite_input_values = [
          input_values[0].reshape((1, parameters["batch_size"],
                                   parameters["units"]))
      ]
    return tflite_input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))
  make_zip_of_tests(
      zip_path,
      test_parameters,
      build_graph,
      build_inputs,
      use_frozen_graph=True)
# Toco binary path provided by the generate rule.
bin_path = None
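
For a quick sanity check of the layout the fused kernel expects, a converted model can be run with the Python interpreter: it consumes a single time-major [seq_length, batch_size, units] input tensor. A minimal sketch; the model path is hypothetical and would come from unpacking a generated zip:

# Minimal sketch: run one converted model through the TFLite interpreter.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path="unidirectional_sequence_lstm.tflite")  # hypothetical path
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]  # shape is [time, batch, input]
interpreter.set_tensor(
    inp["index"], np.random.randn(*inp["shape"]).astype(np.float32))
interpreter.invoke()
out = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])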