In Toco, import Snapshot as Identity op.
PiperOrigin-RevId: 239987762
This commit is contained in:
parent
b481ac8b01
commit
22ac557a38
@ -261,6 +261,7 @@ def generated_test_models():
|
|||||||
"global_batch_norm",
|
"global_batch_norm",
|
||||||
"greater",
|
"greater",
|
||||||
"greater_equal",
|
"greater_equal",
|
||||||
|
"identity",
|
||||||
"sum",
|
"sum",
|
||||||
"l2norm",
|
"l2norm",
|
||||||
"l2norm_shared_epsilon",
|
"l2norm_shared_epsilon",
|
||||||
|
@ -55,6 +55,8 @@ from tensorflow.lite.testing import string_util_wrapper
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework import graph_util as tf_graph_util
|
from tensorflow.python.framework import graph_util as tf_graph_util
|
||||||
from tensorflow.python.ops import rnn
|
from tensorflow.python.ops import rnn
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
RANDOM_SEED = 342
|
RANDOM_SEED = 342
|
||||||
TEST_INPUT_DEPTH = 3
|
TEST_INPUT_DEPTH = 3
|
||||||
@ -739,6 +741,40 @@ def make_elu_tests(options):
|
|||||||
|
|
||||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_make_test_function()
|
||||||
|
def make_identity_tests(options):
|
||||||
|
"""Make a set of tests to do relu."""
|
||||||
|
|
||||||
|
# Chose a set of parameters
|
||||||
|
test_parameters = [{
|
||||||
|
"input_shape": [[], [1], [3, 3]],
|
||||||
|
"use_snapshot": [False, True],
|
||||||
|
}]
|
||||||
|
|
||||||
|
def build_graph(parameters):
|
||||||
|
input_tensor = tf.placeholder(
|
||||||
|
dtype=tf.float32, name="input", shape=parameters["input_shape"])
|
||||||
|
# Toco crashes when the model has only one single Identity op. As a
|
||||||
|
# workaround for testing, we put MULs before and after the identity.
|
||||||
|
# TODO(b/129197312): Remove the workaround after the issue is fixed.
|
||||||
|
input_doubled = input_tensor * 2.0
|
||||||
|
if parameters["use_snapshot"]:
|
||||||
|
identity_output = array_ops.snapshot(input_tensor)
|
||||||
|
else:
|
||||||
|
identity_output = tf.identity(input_tensor)
|
||||||
|
out = identity_output * 2.0
|
||||||
|
return [input_tensor], [out]
|
||||||
|
|
||||||
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
|
input_values = create_tensor_data(
|
||||||
|
np.float32, parameters["input_shape"], min_value=-4, max_value=10)
|
||||||
|
return [input_values], sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, [input_values])))
|
||||||
|
|
||||||
|
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
@register_make_test_function()
|
@register_make_test_function()
|
||||||
def make_relu_tests(options):
|
def make_relu_tests(options):
|
||||||
"""Make a set of tests to do relu."""
|
"""Make a set of tests to do relu."""
|
||||||
|
@ -838,7 +838,8 @@ tensorflow::Status ConvertIdentityOperator(
|
|||||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
|
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
|
||||||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
|
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
|
||||||
|
node.op() == "Snapshot");
|
||||||
auto* op = new TensorFlowIdentityOperator;
|
auto* op = new TensorFlowIdentityOperator;
|
||||||
// Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
|
// Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
|
||||||
// identity nodes with multiple inputs, but the other inputs seem
|
// identity nodes with multiple inputs, but the other inputs seem
|
||||||
@ -2522,6 +2523,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
|
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
|
||||||
{"SquaredDifference",
|
{"SquaredDifference",
|
||||||
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
|
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
|
||||||
|
{"Snapshot", ConvertIdentityOperator},
|
||||||
{"Squeeze", ConvertSqueezeOperator},
|
{"Squeeze", ConvertSqueezeOperator},
|
||||||
{"StopGradient", ConvertIdentityOperator},
|
{"StopGradient", ConvertIdentityOperator},
|
||||||
{"StridedSlice", ConvertStridedSliceOperator},
|
{"StridedSlice", ConvertStridedSliceOperator},
|
||||||
|
Loading…
Reference in New Issue
Block a user