diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 1ccf7d4d0e7..b5960d6f8d9 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -34,6 +34,7 @@ gen_zipped_test_files( "l2norm.zip", "local_response_norm.zip", "log_softmax.zip", + "lstm.zip", "max_pool.zip", "mean.zip", "mul.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 2cbac7caa65..2481add7691 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -46,6 +46,7 @@ from google.protobuf import text_format # TODO(aselle): switch to TensorFlow's resource_loader from tensorflow.contrib.lite.testing import generate_examples_report as report_lib from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", @@ -108,11 +109,23 @@ KNOWN_BUGS = { } +class ExtraTocoOptions(object): + """Additional toco options besides input, output, shape.""" + + def __init__(self): + # Whether to ignore control dependency nodes. + self.drop_control_dependency = False + # Allow custom ops in the toco conversion. + self.allow_custom_ops = False + # Rnn states that are used to support rnn / lstm cells. + self.rnn_states = None + + def toco_options(data_types, input_arrays, output_arrays, shapes, - drop_control_dependency): + extra_toco_options=ExtraTocoOptions()): """Create TOCO options to process a model. Args: @@ -120,8 +133,7 @@ def toco_options(data_types, input_arrays: names of the input tensors output_arrays: name of the output tensors shapes: shapes of the input tensors - drop_control_dependency: whether to ignore control dependency nodes. - + extra_toco_options: additional toco options Returns: the options in a string.
""" @@ -137,37 +149,15 @@ def toco_options(data_types, " --input_arrays=%s" % ",".join(input_arrays) + " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) - if drop_control_dependency: + if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" + if extra_toco_options.allow_custom_ops: + s += " --allow_custom_ops" + if extra_toco_options.rnn_states: + s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") return s -def write_toco_options(filename, - data_types, - input_arrays, - output_arrays, - shapes, - drop_control_dependency=False): - """Create TOCO options to process a model. - - Args: - filename: Filename to write the options to. - data_types: input and inference types used by TOCO. - input_arrays: names of the input tensors - output_arrays: names of the output tensors - shapes: shapes of the input tensors - drop_control_dependency: whether to ignore control dependency nodes. - """ - with open(filename, "w") as fp: - fp.write( - toco_options( - data_types=data_types, - input_arrays=input_arrays, - output_arrays=output_arrays, - shapes=shapes, - drop_control_dependency=drop_control_dependency)) - - def write_examples(fp, examples): """Given a list `examples`, write a text format representation. @@ -285,12 +275,14 @@ def make_control_dep_tests(zip_path): return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) + extra_toco_options = ExtraTocoOptions() + extra_toco_options.drop_control_dependency = True make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - drop_control_dependency=True) + extra_toco_options) def toco_convert(graph_def_str, input_tensors, output_tensors, - drop_control_dependency=False): + extra_toco_options): """Convert a model's graph def into a tflite model. 
NOTE: this currently shells out to the toco binary, but we would like @@ -298,9 +290,9 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, Args: graph_def_str: Graph def proto in serialized string format. - input_tensors: List of input tensor tuples `(name, shape, type)` - output_tensors: List of output tensors (names) - drop_control_dependency: whether to ignore control dependency nodes. + input_tensors: List of input tensor tuples `(name, shape, type)`. + output_tensors: List of output tensors (names). + extra_toco_options: Additional toco options. Returns: output tflite model, log_txt from conversion @@ -312,7 +304,7 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, input_arrays=[x[0] for x in input_tensors], shapes=[x[1] for x in input_tensors], output_arrays=output_tensors, - drop_control_dependency=drop_control_dependency) + extra_toco_options=extra_toco_options) with tempfile.NamedTemporaryFile() as graphdef_file, \ tempfile.NamedTemporaryFile() as output_file, \ @@ -341,7 +333,8 @@ def make_zip_of_tests(zip_path, test_parameters, make_graph, make_test_inputs, - drop_control_dependency=False): + extra_toco_options=ExtraTocoOptions(), + use_frozen_graph=False): """Helper to make a zip file of a bunch of TensorFlow models. This does a cartestian product of the dictionary of test_parameters and @@ -359,7 +352,9 @@ def make_zip_of_tests(zip_path, `[input1, input2, ...], [output1, output2, ...]` make_test_inputs: function taking `curr_params`, `session`, `input_tensors`, `output_tensors` and returns tuple `(input_values, output_values)`. - drop_control_dependency: whether to ignore control dependency nodes. + extra_toco_options: Additional toco options. + use_frozen_graph: Whether to freeze the graph before running the toco converter. + Raises: RuntimeError: if there are toco errors that can't be ignored.
""" @@ -419,21 +414,25 @@ def make_zip_of_tests(zip_path, return None, report report["toco"] = report_lib.FAILED report["tf"] = report_lib.SUCCESS - # Convert graph to toco + input_tensors = [(input_tensor.name.split(":")[0], + input_tensor.get_shape(), input_tensor.dtype) + for input_tensor in inputs] + output_tensors = [normalize_output_name(out.name) for out in outputs] + graph_def = freeze_graph( + sess, + tf.global_variables() + inputs + + outputs) if use_frozen_graph else sess.graph_def tflite_model_binary, toco_log = toco_convert( - sess.graph_def.SerializeToString(), - [(input_tensor.name.split(":")[0], input_tensor.get_shape(), - input_tensor.dtype) for input_tensor in inputs], - [normalize_output_name(out.name) for out in outputs], - drop_control_dependency) + graph_def.SerializeToString(), input_tensors, output_tensors, + extra_toco_options) report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None else report_lib.FAILED) report["toco_log"] = toco_log if FLAGS.save_graphdefs: archive.writestr(label + ".pb", - text_format.MessageToString(sess.graph_def), + text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) if tflite_model_binary: @@ -1761,6 +1760,84 @@ def make_strided_slice_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_lstm_tests(zip_path): + """Make a set of tests to do basic Lstm cell.""" + + test_parameters = [ + { + "dtype": [tf.float32], + "num_batchs": [1], + "time_step_size": [1], + "input_vec_size": [3], + "num_cells": [4], + }, + ] + + def build_graph(parameters): + """Build a simple graph with BasicLSTMCell.""" + + num_batchs = parameters["num_batchs"] + time_step_size = parameters["time_step_size"] + input_vec_size = parameters["input_vec_size"] + num_cells = parameters["num_cells"] + inputs_after_split = [] + for i in xrange(time_step_size): + one_timestamp_input = tf.placeholder( + dtype=parameters["dtype"], + name="split_{}".format(i), + shape=[num_batchs, 
input_vec_size]) + inputs_after_split.append(one_timestamp_input) + # Currently lstm identifier has a few limitations: only supports + # forget_bias == 0, inner state activation == tanh. + # TODO(zhixianyan): Add another test with forget_bias == 1. + # TODO(zhixianyan): Add another test with relu as activation. + lstm_cell = tf.contrib.rnn.BasicLSTMCell( + num_cells, forget_bias=0.0, state_is_tuple=True) + cell_outputs, _ = rnn.static_rnn( + lstm_cell, inputs_after_split, dtype=tf.float32) + out = cell_outputs[-1] + return inputs_after_split, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Feed inputs, assign variables, and freeze graph.""" + + with tf.variable_scope("", reuse=True): + kernel = tf.get_variable("rnn/basic_lstm_cell/kernel") + bias = tf.get_variable("rnn/basic_lstm_cell/bias") + kernel_values = create_tensor_data( + parameters["dtype"], [kernel.shape[0], kernel.shape[1]], -1, 1) + bias_values = create_tensor_data(parameters["dtype"], [bias.shape[0]], 0, + 1) + sess.run(tf.group(kernel.assign(kernel_values), bias.assign(bias_values))) + + num_batchs = parameters["num_batchs"] + time_step_size = parameters["time_step_size"] + input_vec_size = parameters["input_vec_size"] + input_values = [] + for _ in xrange(time_step_size): + tensor_data = create_tensor_data(parameters["dtype"], + [num_batchs, input_vec_size], 0, 1) + input_values.append(tensor_data) + out = sess.run(outputs, feed_dict=dict(zip(inputs, input_values))) + return input_values, out + + # TODO(zhixianyan): Automatically generate rnn_states for lstm cell.
+ extra_toco_options = ExtraTocoOptions() + extra_toco_options.rnn_states = ( + "{state_array:rnn/BasicLSTMCellZeroState/zeros," + "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4}," + "{state_array:rnn/BasicLSTMCellZeroState/zeros_1," + "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}") + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + extra_toco_options, + use_frozen_graph=True) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1850,6 +1927,7 @@ def main(unused_args): "strided_slice.zip": make_strided_slice_tests, "exp.zip": make_exp_tests, "log_softmax.zip": make_log_softmax_tests, + "lstm.zip": make_lstm_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 89a5841371c..976363fd444 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -266,6 +266,7 @@ INSTANTIATE_TESTS(sub) INSTANTIATE_TESTS(split) INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) +INSTANTIATE_TESTS(lstm) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(squeeze) INSTANTIATE_TESTS(strided_slice) diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc index c8f2e49f930..389688d5520 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.cc +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -192,27 +192,25 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, int model_outputs = interpreter->outputs().size(); TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size()); for (size_t i = 0; i < interpreter->outputs().size(); i++) { + bool tensors_differ = false; int output_index = 
interpreter->outputs()[i]; if (const float* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { float computed = data[idx]; float reference = example.outputs[0].flat_data[idx]; float diff = std::abs(computed - reference); - bool error_is_large = false; // For very small numbers, try absolute error, otherwise go with // relative. - if (std::abs(reference) < kRelativeThreshold) { - error_is_large = (diff > kAbsoluteThreshold); - } else { - error_is_large = (diff > kRelativeThreshold * std::abs(reference)); - } - if (error_is_large) { + bool local_tensors_differ = + std::abs(reference) < kRelativeThreshold + ? diff > kAbsoluteThreshold + : diff > kRelativeThreshold * std::abs(reference); + if (local_tensors_differ) { fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n", i, idx, data[idx], reference); - return kTfLiteError; + tensors_differ = local_tensors_differ; } } - fprintf(stderr, "\n"); } else if (const int32_t* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { @@ -221,10 +219,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, if (std::abs(computed - reference) > 0) { fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n", i, idx, computed, reference); - return kTfLiteError; + tensors_differ = true; } } - fprintf(stderr, "\n"); } else if (const int64_t* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { @@ -235,14 +232,15 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, "output[%zu][%zu] did not match %" PRId64 " vs reference %" PRId64 "\n", i, idx, computed, reference); - return kTfLiteError; + tensors_differ = true; } } - fprintf(stderr, "\n"); } else { fprintf(stderr, "output[%zu] was not float or int data\n", i); return kTfLiteError; } + fprintf(stderr, "\n"); + if (tensors_differ) return 
kTfLiteError; } return kTfLiteOk; }