In Toco, import Snapshot as Identity op.
PiperOrigin-RevId: 239987762
This commit is contained in:
parent
b481ac8b01
commit
22ac557a38
@ -261,6 +261,7 @@ def generated_test_models():
|
||||
"global_batch_norm",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"identity",
|
||||
"sum",
|
||||
"l2norm",
|
||||
"l2norm_shared_epsilon",
|
||||
|
@ -55,6 +55,8 @@ from tensorflow.lite.testing import string_util_wrapper
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import graph_util as tf_graph_util
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
RANDOM_SEED = 342
|
||||
TEST_INPUT_DEPTH = 3
|
||||
@ -739,6 +741,40 @@ def make_elu_tests(options):
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_identity_tests(options):
|
||||
"""Make a set of tests to do relu."""
|
||||
|
||||
# Chose a set of parameters
|
||||
test_parameters = [{
|
||||
"input_shape": [[], [1], [3, 3]],
|
||||
"use_snapshot": [False, True],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
input_tensor = tf.placeholder(
|
||||
dtype=tf.float32, name="input", shape=parameters["input_shape"])
|
||||
# Toco crashes when the model has only one single Identity op. As a
|
||||
# workaround for testing, we put MULs before and after the identity.
|
||||
# TODO(b/129197312): Remove the workaround after the issue is fixed.
|
||||
input_doubled = input_tensor * 2.0
|
||||
if parameters["use_snapshot"]:
|
||||
identity_output = array_ops.snapshot(input_tensor)
|
||||
else:
|
||||
identity_output = tf.identity(input_tensor)
|
||||
out = identity_output * 2.0
|
||||
return [input_tensor], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_values = create_tensor_data(
|
||||
np.float32, parameters["input_shape"], min_value=-4, max_value=10)
|
||||
return [input_values], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_values])))
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_relu_tests(options):
|
||||
"""Make a set of tests to do relu."""
|
||||
|
@ -838,7 +838,8 @@ tensorflow::Status ConvertIdentityOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
|
||||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
|
||||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
|
||||
node.op() == "Snapshot");
|
||||
auto* op = new TensorFlowIdentityOperator;
|
||||
// Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
|
||||
// identity nodes with multiple inputs, but the other inputs seem
|
||||
@ -2522,6 +2523,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
|
||||
{"SquaredDifference",
|
||||
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
|
||||
{"Snapshot", ConvertIdentityOperator},
|
||||
{"Squeeze", ConvertSqueezeOperator},
|
||||
{"StopGradient", ConvertIdentityOperator},
|
||||
{"StridedSlice", ConvertStridedSliceOperator},
|
||||
|
Loading…
Reference in New Issue
Block a user