Re-evaluate the iterator after we modify the underlying operators vector. Not doing so might cause segfault when we reference the old iterator again.
PiperOrigin-RevId: 241666580
This commit is contained in:
parent
1597edbeed
commit
a7575051f1
@ -749,7 +749,7 @@ def make_elu_tests(options):
|
|||||||
|
|
||||||
@register_make_test_function()
|
@register_make_test_function()
|
||||||
def make_identity_tests(options):
|
def make_identity_tests(options):
|
||||||
"""Make a set of tests to do relu."""
|
"""Make a set of tests to do identity."""
|
||||||
|
|
||||||
# Chose a set of parameters
|
# Chose a set of parameters
|
||||||
test_parameters = [{
|
test_parameters = [{
|
||||||
@ -760,16 +760,11 @@ def make_identity_tests(options):
|
|||||||
def build_graph(parameters):
|
def build_graph(parameters):
|
||||||
input_tensor = tf.placeholder(
|
input_tensor = tf.placeholder(
|
||||||
dtype=tf.float32, name="input", shape=parameters["input_shape"])
|
dtype=tf.float32, name="input", shape=parameters["input_shape"])
|
||||||
# Toco crashes when the model has only one single Identity op. As a
|
|
||||||
# workaround for testing, we put MULs before and after the identity.
|
|
||||||
# TODO(b/129197312): Remove the workaround after the issue is fixed.
|
|
||||||
input_doubled = input_tensor * 2.0
|
|
||||||
if parameters["use_snapshot"]:
|
if parameters["use_snapshot"]:
|
||||||
identity_output = array_ops.snapshot(input_tensor)
|
identity_output = array_ops.snapshot(input_tensor)
|
||||||
else:
|
else:
|
||||||
identity_output = tf.identity(input_tensor)
|
identity_output = tf.identity(input_tensor)
|
||||||
out = identity_output * 2.0
|
return [input_tensor], [identity_output]
|
||||||
return [input_tensor], [out]
|
|
||||||
|
|
||||||
def build_inputs(parameters, sess, inputs, outputs):
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
input_values = create_tensor_data(
|
input_values = create_tensor_data(
|
||||||
|
|||||||
@ -64,7 +64,7 @@ void Reroute(const string& from, const string& to, Model* model) {
|
|||||||
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
||||||
Model* model, std::size_t op_index,
|
Model* model, std::size_t op_index,
|
||||||
int input_index) {
|
int input_index) {
|
||||||
const auto passthru_it = model->operators.begin() + op_index;
|
auto passthru_it = model->operators.begin() + op_index;
|
||||||
auto* passthru_op = passthru_it->get();
|
auto* passthru_op = passthru_it->get();
|
||||||
CHECK_EQ(passthru_op->outputs.size(), 1);
|
CHECK_EQ(passthru_op->outputs.size(), 1);
|
||||||
CHECK_GE(passthru_op->inputs.size(), 1);
|
CHECK_GE(passthru_op->inputs.size(), 1);
|
||||||
@ -127,6 +127,8 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
|
|||||||
// copy itself is a trivial reshape and we'd go into an infinite loop!
|
// copy itself is a trivial reshape and we'd go into an infinite loop!
|
||||||
transformation->AddMessageF("Replacing with a copy (reshape) instead");
|
transformation->AddMessageF("Replacing with a copy (reshape) instead");
|
||||||
InsertCopyOperator(model, main_input_name, output_name);
|
InsertCopyOperator(model, main_input_name, output_name);
|
||||||
|
// To avoid using invalidated iterator, evaluate passthru_it again.
|
||||||
|
passthru_it = model->operators.begin() + op_index;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user