Re-evaluate the iterator after we modify the underlying operators vector. Not doing so might cause segfault when we reference the old iterator again.
PiperOrigin-RevId: 241666580
This commit is contained in:
		
							parent
							
								
									1597edbeed
								
							
						
					
					
						commit
						a7575051f1
					
				@ -749,7 +749,7 @@ def make_elu_tests(options):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@register_make_test_function()
 | 
					@register_make_test_function()
 | 
				
			||||||
def make_identity_tests(options):
 | 
					def make_identity_tests(options):
 | 
				
			||||||
  """Make a set of tests to do relu."""
 | 
					  """Make a set of tests to do identity."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Chose a set of parameters
 | 
					  # Chose a set of parameters
 | 
				
			||||||
  test_parameters = [{
 | 
					  test_parameters = [{
 | 
				
			||||||
@ -760,16 +760,11 @@ def make_identity_tests(options):
 | 
				
			|||||||
  def build_graph(parameters):
 | 
					  def build_graph(parameters):
 | 
				
			||||||
    input_tensor = tf.placeholder(
 | 
					    input_tensor = tf.placeholder(
 | 
				
			||||||
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
 | 
					        dtype=tf.float32, name="input", shape=parameters["input_shape"])
 | 
				
			||||||
    # Toco crashes when the model has only one single Identity op. As a
 | 
					 | 
				
			||||||
    # workaround for testing, we put MULs before and after the identity.
 | 
					 | 
				
			||||||
    # TODO(b/129197312): Remove the workaround after the issue is fixed.
 | 
					 | 
				
			||||||
    input_doubled = input_tensor * 2.0
 | 
					 | 
				
			||||||
    if parameters["use_snapshot"]:
 | 
					    if parameters["use_snapshot"]:
 | 
				
			||||||
      identity_output = array_ops.snapshot(input_tensor)
 | 
					      identity_output = array_ops.snapshot(input_tensor)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
      identity_output = tf.identity(input_tensor)
 | 
					      identity_output = tf.identity(input_tensor)
 | 
				
			||||||
    out = identity_output * 2.0
 | 
					    return [input_tensor], [identity_output]
 | 
				
			||||||
    return [input_tensor], [out]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def build_inputs(parameters, sess, inputs, outputs):
 | 
					  def build_inputs(parameters, sess, inputs, outputs):
 | 
				
			||||||
    input_values = create_tensor_data(
 | 
					    input_values = create_tensor_data(
 | 
				
			||||||
 | 
				
			|||||||
@ -64,7 +64,7 @@ void Reroute(const string& from, const string& to, Model* model) {
 | 
				
			|||||||
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
 | 
					bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
 | 
				
			||||||
                                Model* model, std::size_t op_index,
 | 
					                                Model* model, std::size_t op_index,
 | 
				
			||||||
                                int input_index) {
 | 
					                                int input_index) {
 | 
				
			||||||
  const auto passthru_it = model->operators.begin() + op_index;
 | 
					  auto passthru_it = model->operators.begin() + op_index;
 | 
				
			||||||
  auto* passthru_op = passthru_it->get();
 | 
					  auto* passthru_op = passthru_it->get();
 | 
				
			||||||
  CHECK_EQ(passthru_op->outputs.size(), 1);
 | 
					  CHECK_EQ(passthru_op->outputs.size(), 1);
 | 
				
			||||||
  CHECK_GE(passthru_op->inputs.size(), 1);
 | 
					  CHECK_GE(passthru_op->inputs.size(), 1);
 | 
				
			||||||
@ -127,6 +127,8 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
 | 
				
			|||||||
      // copy itself is a trivial reshape and we'd go into an infinite loop!
 | 
					      // copy itself is a trivial reshape and we'd go into an infinite loop!
 | 
				
			||||||
      transformation->AddMessageF("Replacing with a copy (reshape) instead");
 | 
					      transformation->AddMessageF("Replacing with a copy (reshape) instead");
 | 
				
			||||||
      InsertCopyOperator(model, main_input_name, output_name);
 | 
					      InsertCopyOperator(model, main_input_name, output_name);
 | 
				
			||||||
 | 
					      // To avoid using invalidated iterator, evaluate passthru_it again.
 | 
				
			||||||
 | 
					      passthru_it = model->operators.begin() + op_index;
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      return false;
 | 
					      return false;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user