Generate example for basic LSTM cell in TFLite
PiperOrigin-RevId: 186656247
parent 2e707494c4
commit 78916e7338
@@ -34,6 +34,7 @@ gen_zipped_test_files(
         "l2norm.zip",
         "local_response_norm.zip",
         "log_softmax.zip",
+        "lstm.zip",
         "max_pool.zip",
         "mean.zip",
         "mul.zip",
@@ -46,6 +46,7 @@ from google.protobuf import text_format
 # TODO(aselle): switch to TensorFlow's resource_loader
 from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
 from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.ops import rnn
 
 parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
 parser.add_argument("output_path",
@@ -108,11 +109,23 @@ KNOWN_BUGS = {
 }
 
 
+class ExtraTocoOptions(object):
+  """Additional toco options besides input, output, shape."""
+
+  def __init__(self):
+    # Whether to ignore control dependency nodes.
+    self.drop_control_dependency = False
+    # Allow custom ops in the toco conversion.
+    self.allow_custom_ops = False
+    # Rnn states that are used to support rnn / lstm cells.
+    self.rnn_states = None
+
+
 def toco_options(data_types,
                  input_arrays,
                  output_arrays,
                  shapes,
-                 drop_control_dependency):
+                 extra_toco_options=ExtraTocoOptions()):
   """Create TOCO options to process a model.
 
   Args:
@@ -120,8 +133,7 @@ def toco_options(data_types,
     input_arrays: names of the input tensors
     output_arrays: names of the output tensors
     shapes: shapes of the input tensors
-    drop_control_dependency: whether to ignore control dependency nodes.
-
+    extra_toco_options: additional toco options
   Returns:
     the options in a string.
   """
@@ -137,37 +149,15 @@ def toco_options(data_types,
        " --input_arrays=%s" % ",".join(input_arrays) +
        " --input_shapes=%s" % shape_str +
        " --output_arrays=%s" % ",".join(output_arrays))
-  if drop_control_dependency:
+  if extra_toco_options.drop_control_dependency:
     s += " --drop_control_dependency"
+  if extra_toco_options.allow_custom_ops:
+    s += " --allow_custom_ops"
+  if extra_toco_options.rnn_states:
+    s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
   return s
 
 
-def write_toco_options(filename,
-                       data_types,
-                       input_arrays,
-                       output_arrays,
-                       shapes,
-                       drop_control_dependency=False):
-  """Create TOCO options to process a model.
-
-  Args:
-    filename: Filename to write the options to.
-    data_types: input and inference types used by TOCO.
-    input_arrays: names of the input tensors
-    output_arrays: names of the output tensors
-    shapes: shapes of the input tensors
-    drop_control_dependency: whether to ignore control dependency nodes.
-  """
-  with open(filename, "w") as fp:
-    fp.write(
-        toco_options(
-            data_types=data_types,
-            input_arrays=input_arrays,
-            output_arrays=output_arrays,
-            shapes=shapes,
-            drop_control_dependency=drop_control_dependency))
-
-
 def write_examples(fp, examples):
   """Given a list `examples`, write a text format representation.
 
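For reference, a minimal sketch of how the new options object is turned into toco flags. It assumes ExtraTocoOptions and toco_options from the hunk above are in scope; the argument values (data type string, array names, shapes) are illustrative assumptions, not values taken from this change.

# Illustrative usage only; the argument values below are assumptions.
opts = ExtraTocoOptions()
opts.drop_control_dependency = True
opts.rnn_states = ("{state_array:rnn/BasicLSTMCellZeroState/zeros,"
                   "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4}")

flags = toco_options(
    data_types=["FLOAT"],  # assumed data type string
    input_arrays=["split_0"],
    output_arrays=["output"],
    shapes=[[1, 3]],
    extra_toco_options=opts)
print(flags)  # ends with: ... --drop_control_dependency --rnn_states='{...}'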
@@ -285,12 +275,14 @@ def make_control_dep_tests(zip_path):
     return [input_values], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_values])))
 
+  extra_toco_options = ExtraTocoOptions()
+  extra_toco_options.drop_control_dependency = True
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
-                    drop_control_dependency=True)
+                    extra_toco_options)
 
 
 def toco_convert(graph_def_str, input_tensors, output_tensors,
-                 drop_control_dependency=False):
+                 extra_toco_options):
   """Convert a model's graph def into a tflite model.
 
   NOTE: this currently shells out to the toco binary, but we would like
@@ -298,9 +290,9 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
 
   Args:
     graph_def_str: Graph def proto in serialized string format.
-    input_tensors: List of input tensor tuples `(name, shape, type)`
-    output_tensors: List of output tensors (names)
-    drop_control_dependency: whether to ignore control dependency nodes.
+    input_tensors: List of input tensor tuples `(name, shape, type)`.
+    output_tensors: List of output tensors (names).
+    extra_toco_options: Additional toco options.
 
   Returns:
     output tflite model, log_txt from conversion
@@ -312,7 +304,7 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
       input_arrays=[x[0] for x in input_tensors],
       shapes=[x[1] for x in input_tensors],
       output_arrays=output_tensors,
-      drop_control_dependency=drop_control_dependency)
+      extra_toco_options=extra_toco_options)
 
   with tempfile.NamedTemporaryFile() as graphdef_file, \
        tempfile.NamedTemporaryFile() as output_file, \
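The rest of toco_convert (not shown in this hunk) writes the serialized graph to the temp files above and invokes the toco binary. A self-contained sketch of that shell-out pattern follows; the --input_file/--output_file flag names and the exact command line are assumptions for illustration, not the harness's verbatim code.

import subprocess
import tempfile


def run_toco_sketch(toco_bin_path, graph_def_str, option_flags):
  """Write the graph def to a temp file, run toco, return (model, log)."""
  with tempfile.NamedTemporaryFile() as graphdef_file, \
       tempfile.NamedTemporaryFile() as output_file, \
       tempfile.NamedTemporaryFile("w+") as stdout_file:
    graphdef_file.write(graph_def_str)
    graphdef_file.flush()
    # --input_file and --output_file are assumed flag names for this sketch.
    cmd = ("%s --input_file=%s --output_file=%s %s" %
           (toco_bin_path, graphdef_file.name, output_file.name, option_flags))
    exit_code = subprocess.call(cmd, shell=True, stdout=stdout_file,
                                stderr=subprocess.STDOUT)
    stdout_file.seek(0)
    log = stdout_file.read()
    model = output_file.read() if exit_code == 0 else None
    return model, log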
@@ -341,7 +333,8 @@ def make_zip_of_tests(zip_path,
                       test_parameters,
                       make_graph,
                       make_test_inputs,
-                      drop_control_dependency=False):
+                      extra_toco_options=ExtraTocoOptions(),
+                      use_frozen_graph=False):
   """Helper to make a zip file of a bunch of TensorFlow models.
 
   This does a cartesian product of the dictionary of test_parameters and
@@ -359,7 +352,9 @@ def make_zip_of_tests(zip_path,
         `[input1, input2, ...], [output1, output2, ...]`
     make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
       `output_tensors` and returns tuple `(input_values, output_values)`.
-    drop_control_dependency: whether to ignore control dependency nodes.
+    extra_toco_options: Additional toco options.
+    use_frozen_graph: Whether or not to freeze the graph before toco conversion.
 
   Raises:
     RuntimeError: if there are toco errors that can't be ignored.
   """
@@ -419,21 +414,25 @@ def make_zip_of_tests(zip_path,
         return None, report
       report["toco"] = report_lib.FAILED
       report["tf"] = report_lib.SUCCESS
 
       # Convert graph to toco
+      input_tensors = [(input_tensor.name.split(":")[0],
+                        input_tensor.get_shape(), input_tensor.dtype)
+                       for input_tensor in inputs]
+      output_tensors = [normalize_output_name(out.name) for out in outputs]
+      graph_def = freeze_graph(
+          sess,
+          tf.global_variables() + inputs +
+          outputs) if use_frozen_graph else sess.graph_def
       tflite_model_binary, toco_log = toco_convert(
-          sess.graph_def.SerializeToString(),
-          [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
-            input_tensor.dtype) for input_tensor in inputs],
-          [normalize_output_name(out.name) for out in outputs],
-          drop_control_dependency)
+          graph_def.SerializeToString(), input_tensors, output_tensors,
+          extra_toco_options)
       report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
                         else report_lib.FAILED)
       report["toco_log"] = toco_log
 
       if FLAGS.save_graphdefs:
         archive.writestr(label + ".pb",
-                         text_format.MessageToString(sess.graph_def),
+                         text_format.MessageToString(graph_def),
                          zipfile.ZIP_DEFLATED)
 
       if tflite_model_binary:
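The freeze_graph helper called above is not part of this diff. A minimal sketch of what such a helper can do, assuming it folds variables into constants via the graph_util module imported earlier; the real helper also receives the input and output tensors, which this sketch ignores.

from tensorflow.python.framework import graph_util as tf_graph_util


def freeze_graph_sketch(session, output_tensors):
  """Fold the session's variables into constants for the given outputs."""
  output_names = [tensor.name.split(":")[0] for tensor in output_tensors]
  return tf_graph_util.convert_variables_to_constants(
      session, session.graph_def, output_names)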
@@ -1761,6 +1760,84 @@ def make_strided_slice_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_lstm_tests(zip_path):
+  """Make a set of tests for a basic LSTM cell."""
+
+  test_parameters = [
+      {
+          "dtype": [tf.float32],
+          "num_batchs": [1],
+          "time_step_size": [1],
+          "input_vec_size": [3],
+          "num_cells": [4],
+      },
+  ]
+
+  def build_graph(parameters):
+    """Build a simple graph with BasicLSTMCell."""
+
+    num_batchs = parameters["num_batchs"]
+    time_step_size = parameters["time_step_size"]
+    input_vec_size = parameters["input_vec_size"]
+    num_cells = parameters["num_cells"]
+    inputs_after_split = []
+    for i in xrange(time_step_size):
+      one_timestamp_input = tf.placeholder(
+          dtype=parameters["dtype"],
+          name="split_{}".format(i),
+          shape=[num_batchs, input_vec_size])
+      inputs_after_split.append(one_timestamp_input)
+    # Currently the lstm identifier has a few limitations: it only supports
+    # forget_bias == 0 and inner state activation == tanh.
+    # TODO(zhixianyan): Add another test with forget_bias == 1.
+    # TODO(zhixianyan): Add another test with relu as activation.
+    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
+        num_cells, forget_bias=0.0, state_is_tuple=True)
+    cell_outputs, _ = rnn.static_rnn(
+        lstm_cell, inputs_after_split, dtype=tf.float32)
+    out = cell_outputs[-1]
+    return inputs_after_split, [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    """Feed inputs, assign variables, and freeze the graph."""
+
+    with tf.variable_scope("", reuse=True):
+      kernel = tf.get_variable("rnn/basic_lstm_cell/kernel")
+      bias = tf.get_variable("rnn/basic_lstm_cell/bias")
+      kernel_values = create_tensor_data(
+          parameters["dtype"], [kernel.shape[0], kernel.shape[1]], -1, 1)
+      bias_values = create_tensor_data(parameters["dtype"], [bias.shape[0]], 0,
+                                       1)
+      sess.run(tf.group(kernel.assign(kernel_values), bias.assign(bias_values)))
+
+    num_batchs = parameters["num_batchs"]
+    time_step_size = parameters["time_step_size"]
+    input_vec_size = parameters["input_vec_size"]
+    input_values = []
+    for _ in xrange(time_step_size):
+      tensor_data = create_tensor_data(parameters["dtype"],
+                                       [num_batchs, input_vec_size], 0, 1)
+      input_values.append(tensor_data)
+    out = sess.run(outputs, feed_dict=dict(zip(inputs, input_values)))
+    return input_values, out
+
+  # TODO(zhixianyan): Automatically generate rnn_states for lstm cell.
+  extra_toco_options = ExtraTocoOptions()
+  extra_toco_options.rnn_states = (
+      "{state_array:rnn/BasicLSTMCellZeroState/zeros,"
+      "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4},"
+      "{state_array:rnn/BasicLSTMCellZeroState/zeros_1,"
+      "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}")
+
+  make_zip_of_tests(
+      zip_path,
+      test_parameters,
+      build_graph,
+      build_inputs,
+      extra_toco_options,
+      use_frozen_graph=True)
+
+
 def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
   """Given an input perform a sequence of TensorFlow ops to produce l2pool."""
   return tf.sqrt(tf.nn.avg_pool(
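The rnn_states string above hard-codes the zero-state and back-edge tensor names of BasicLSTMCell. A small standalone check (not part of the harness; assumes the same TF 1.x contrib APIs used in this file) that those names actually appear in the graph built by the test:

import tensorflow as tf
from tensorflow.python.ops import rnn

with tf.Graph().as_default():
  # Same shape as the test above: batch of 1, input vector size 3, 4 cells.
  inp = tf.placeholder(tf.float32, shape=[1, 3], name="split_0")
  cell = tf.contrib.rnn.BasicLSTMCell(4, forget_bias=0.0, state_is_tuple=True)
  unused_outputs, unused_state = rnn.static_rnn(cell, [inp], dtype=tf.float32)
  for op in tf.get_default_graph().get_operations():
    if ("BasicLSTMCellZeroState" in op.name or
        "basic_lstm_cell/Add_1" in op.name or
        "basic_lstm_cell/Mul_2" in op.name):
      print(op.name)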
@@ -1850,6 +1927,7 @@ def main(unused_args):
       "strided_slice.zip": make_strided_slice_tests,
       "exp.zip": make_exp_tests,
       "log_softmax.zip": make_log_softmax_tests,
+      "lstm.zip": make_lstm_tests,
   }
   out = FLAGS.zip_to_output
   bin_path = FLAGS.toco
@@ -266,6 +266,7 @@ INSTANTIATE_TESTS(sub)
 INSTANTIATE_TESTS(split)
 INSTANTIATE_TESTS(div)
 INSTANTIATE_TESTS(transpose)
+INSTANTIATE_TESTS(lstm)
 INSTANTIATE_TESTS(mean)
 INSTANTIATE_TESTS(squeeze)
 INSTANTIATE_TESTS(strided_slice)
@@ -192,27 +192,25 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
   int model_outputs = interpreter->outputs().size();
   TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
   for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+    bool tensors_differ = false;
     int output_index = interpreter->outputs()[i];
     if (const float* data = interpreter->typed_tensor<float>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
         float computed = data[idx];
         float reference = example.outputs[0].flat_data[idx];
         float diff = std::abs(computed - reference);
-        bool error_is_large = false;
         // For very small numbers, try absolute error, otherwise go with
         // relative.
-        if (std::abs(reference) < kRelativeThreshold) {
-          error_is_large = (diff > kAbsoluteThreshold);
-        } else {
-          error_is_large = (diff > kRelativeThreshold * std::abs(reference));
-        }
-        if (error_is_large) {
+        bool local_tensors_differ =
+            std::abs(reference) < kRelativeThreshold
+                ? diff > kAbsoluteThreshold
+                : diff > kRelativeThreshold * std::abs(reference);
+        if (local_tensors_differ) {
           fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
                   i, idx, data[idx], reference);
-          return kTfLiteError;
+          tensors_differ = local_tensors_differ;
         }
       }
-      fprintf(stderr, "\n");
     } else if (const int32_t* data =
                    interpreter->typed_tensor<int32_t>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
@@ -221,10 +219,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
         if (std::abs(computed - reference) > 0) {
           fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
                   i, idx, computed, reference);
-          return kTfLiteError;
+          tensors_differ = true;
         }
       }
-      fprintf(stderr, "\n");
     } else if (const int64_t* data =
                    interpreter->typed_tensor<int64_t>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
@@ -235,14 +232,15 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
                   "output[%zu][%zu] did not match %" PRId64
                   " vs reference %" PRId64 "\n",
                   i, idx, computed, reference);
-          return kTfLiteError;
+          tensors_differ = true;
         }
       }
-      fprintf(stderr, "\n");
     } else {
       fprintf(stderr, "output[%zu] was not float or int data\n", i);
       return kTfLiteError;
     }
+    fprintf(stderr, "\n");
+    if (tensors_differ) return kTfLiteError;
   }
   return kTfLiteOk;
 }
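The per-element tolerance rule used in CheckOutputs above, restated as a short Python sketch for readers of the generator script; the threshold constants here are placeholders, not the values defined in the C++ test.

# Placeholder thresholds; the C++ test defines its own kRelativeThreshold
# and kAbsoluteThreshold constants, whose values are not shown in this diff.
RELATIVE_THRESHOLD = 1e-2
ABSOLUTE_THRESHOLD = 1e-4


def output_element_differs(computed, reference):
  """Mirrors the per-element check in CheckOutputs above."""
  diff = abs(computed - reference)
  if abs(reference) < RELATIVE_THRESHOLD:
    # For very small reference values, use an absolute tolerance.
    return diff > ABSOLUTE_THRESHOLD
  # Otherwise compare relative to the reference magnitude.
  return diff > RELATIVE_THRESHOLD * abs(reference)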