Re-evaluate the iterator after we modify the underlying operators vector. Not doing so might cause segfault when we reference the old iterator again.

PiperOrigin-RevId: 241666580
This commit is contained in:
Haoliang Zhang 2019-04-02 22:21:49 -07:00 committed by TensorFlower Gardener
parent 1597edbeed
commit a7575051f1
2 changed files with 5 additions and 8 deletions

View File

@ -749,7 +749,7 @@ def make_elu_tests(options):
@register_make_test_function()
def make_identity_tests(options):
"""Make a set of tests to do relu."""
"""Make a set of tests to do identity."""
# Chose a set of parameters
test_parameters = [{
@ -760,16 +760,11 @@ def make_identity_tests(options):
def build_graph(parameters):
input_tensor = tf.placeholder(
dtype=tf.float32, name="input", shape=parameters["input_shape"])
# Toco crashes when the model has only one single Identity op. As a
# workaround for testing, we put MULs before and after the identity.
# TODO(b/129197312): Remove the workaround after the issue is fixed.
input_doubled = input_tensor * 2.0
if parameters["use_snapshot"]:
identity_output = array_ops.snapshot(input_tensor)
else:
identity_output = tf.identity(input_tensor)
out = identity_output * 2.0
return [input_tensor], [out]
return [input_tensor], [identity_output]
def build_inputs(parameters, sess, inputs, outputs):
input_values = create_tensor_data(

View File

@ -64,7 +64,7 @@ void Reroute(const string& from, const string& to, Model* model) {
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
Model* model, std::size_t op_index,
int input_index) {
const auto passthru_it = model->operators.begin() + op_index;
auto passthru_it = model->operators.begin() + op_index;
auto* passthru_op = passthru_it->get();
CHECK_EQ(passthru_op->outputs.size(), 1);
CHECK_GE(passthru_op->inputs.size(), 1);
@ -127,6 +127,8 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
// copy itself is a trivial reshape and we'd go into an infinite loop!
transformation->AddMessageF("Replacing with a copy (reshape) instead");
InsertCopyOperator(model, main_input_name, output_name);
// To avoid using invalidated iterator, evaluate passthru_it again.
passthru_it = model->operators.begin() + op_index;
} else {
return false;
}