In Toco, import Snapshot as Identity op.

PiperOrigin-RevId: 239987762
This commit is contained in:
Yu-Cheng Ling 2019-03-23 20:18:48 -07:00 committed by TensorFlower Gardener
parent b481ac8b01
commit 22ac557a38
3 changed files with 40 additions and 1 deletions

View File

@ -261,6 +261,7 @@ def generated_test_models():
"global_batch_norm", "global_batch_norm",
"greater", "greater",
"greater_equal", "greater_equal",
"identity",
"sum", "sum",
"l2norm", "l2norm",
"l2norm_shared_epsilon", "l2norm_shared_epsilon",

View File

@ -55,6 +55,8 @@ from tensorflow.lite.testing import string_util_wrapper
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
from tensorflow.python.ops import array_ops
RANDOM_SEED = 342 RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3 TEST_INPUT_DEPTH = 3
@ -739,6 +741,40 @@ def make_elu_tests(options):
make_zip_of_tests(options, test_parameters, build_graph, build_inputs) make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@register_make_test_function()
def make_identity_tests(options):
"""Make a set of tests to do relu."""
# Chose a set of parameters
test_parameters = [{
"input_shape": [[], [1], [3, 3]],
"use_snapshot": [False, True],
}]
def build_graph(parameters):
input_tensor = tf.placeholder(
dtype=tf.float32, name="input", shape=parameters["input_shape"])
# Toco crashes when the model has only one single Identity op. As a
# workaround for testing, we put MULs before and after the identity.
# TODO(b/129197312): Remove the workaround after the issue is fixed.
input_doubled = input_tensor * 2.0
if parameters["use_snapshot"]:
identity_output = array_ops.snapshot(input_tensor)
else:
identity_output = tf.identity(input_tensor)
out = identity_output * 2.0
return [input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_values = create_tensor_data(
np.float32, parameters["input_shape"], min_value=-4, max_value=10)
return [input_values], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_values])))
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@register_make_test_function() @register_make_test_function()
def make_relu_tests(options): def make_relu_tests(options):
"""Make a set of tests to do relu.""" """Make a set of tests to do relu."""

View File

@ -838,7 +838,8 @@ tensorflow::Status ConvertIdentityOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) { Model* model) {
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient"); node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
node.op() == "Snapshot");
auto* op = new TensorFlowIdentityOperator; auto* op = new TensorFlowIdentityOperator;
// Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
// identity nodes with multiple inputs, but the other inputs seem // identity nodes with multiple inputs, but the other inputs seem
@ -2522,6 +2523,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>}, {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
{"SquaredDifference", {"SquaredDifference",
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>}, ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
{"Snapshot", ConvertIdentityOperator},
{"Squeeze", ConvertSqueezeOperator}, {"Squeeze", ConvertSqueezeOperator},
{"StopGradient", ConvertIdentityOperator}, {"StopGradient", ConvertIdentityOperator},
{"StridedSlice", ConvertStridedSliceOperator}, {"StridedSlice", ConvertStridedSliceOperator},