From a73e0bb98a1e37022e688d3fa8a7aedf636f7059 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 13 Nov 2019 19:04:56 -0800 Subject: [PATCH] Fix ophinted model with pure identity nodes as outputs. PiperOrigin-RevId: 280331023 Change-Id: Ica0b741483500ff8b900c766e4b3dc9f4bae9bf5 --- tensorflow/lite/python/util.py | 1 - tensorflow/lite/testing/op_tests/unidirectional_sequence_lstm.py | 1 + tensorflow/lite/testing/op_tests/unidirectional_sequence_rnn.py | 1 + 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 9cb8e6e535f..3c1630acb6f 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -226,7 +226,6 @@ def _convert_op_hints_if_present(sess, graph_def, output_tensors, graph_def = tf_graph_util.convert_variables_to_constants( sess, graph_def, output_arrays + hinted_outputs_nodes) graph_def = convert_op_hints_to_stubs(graph_def=graph_def) - graph_def = tf_graph_util.remove_training_nodes(graph_def) return graph_def diff --git a/tensorflow/lite/testing/op_tests/unidirectional_sequence_lstm.py b/tensorflow/lite/testing/op_tests/unidirectional_sequence_lstm.py index f3221e28477..f82ce53ea8e 100644 --- a/tensorflow/lite/testing/op_tests/unidirectional_sequence_lstm.py +++ b/tensorflow/lite/testing/op_tests/unidirectional_sequence_lstm.py @@ -64,6 +64,7 @@ def make_unidirectional_sequence_lstm_tests(options): outs, _ = tf.nn.static_rnn(lstm_cell, input_values, dtype=tf.float32) real_output = tf.zeros([1], dtype=tf.float32) + outs[-1] + real_output = tf.identity(real_output) return input_values, [real_output] def build_inputs(parameters, sess, inputs, outputs): diff --git a/tensorflow/lite/testing/op_tests/unidirectional_sequence_rnn.py b/tensorflow/lite/testing/op_tests/unidirectional_sequence_rnn.py index 5bab7673b03..80966905b4c 100644 --- a/tensorflow/lite/testing/op_tests/unidirectional_sequence_rnn.py +++ b/tensorflow/lite/testing/op_tests/unidirectional_sequence_rnn.py @@ -61,6 +61,7 @@ def make_unidirectional_sequence_rnn_tests(options): outs, _ = tf.nn.static_rnn(rnn_cell, input_values, dtype=tf.float32) real_output = tf.zeros([1], dtype=tf.float32) + outs[-1] + real_output = tf.identity(real_output) return input_values, [real_output] def build_inputs(parameters, sess, inputs, outputs):