diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 69685409a35..9c25eb08d8f 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -1665,7 +1665,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item, gtl::FlatSet add_ctrl_inputs; // Remove all invalidated control inputs. - for (int idx = 0; idx < node.input_size(); ++idx) { + for (int idx = 0; idx < node.input_size(); /* see below */) { // TODO(ezhulenev): Use non-allocating TensorId after migrating // `control_overrides()` to absl::flat_hash_set. SafeTensorId input_tensor = ParseTensorName(node.input(idx)); @@ -1685,6 +1685,10 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item, for (const string& override : overrides->second) { add_ctrl_inputs.insert(AsControlDependency(override)); } + } else { + // Go to the next input only if the current one was not invalidated, + // otherwise we need to check the swapped input as well. + ++idx; } } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index 93a2fcda7bf..de091dbe98c 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -812,8 +812,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) { // Return result of multiplication and a current value of the variable. NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice), - NDef("out_2", "ReadVariableOp", {"v", "^f2"}, {{"dtype", DT_FLOAT}}, - kDevice)}, + NDef("out_2", "ReadVariableOp", {"v", "^f1", "^f2"}, + {{"dtype", DT_FLOAT}}, kDevice)}, // Function library. {mul_func}); @@ -860,8 +860,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) { // Return values read directly from inlined nodes. NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice), - NDef("out_2", "ReadVariableOp", {"v", "^f2/add"}, {{"dtype", DT_FLOAT}}, - kDevice)}, + NDef("out_2", "ReadVariableOp", {"v", "^f1/add", "^f2/add"}, + {{"dtype", DT_FLOAT}}, kDevice)}, // Function library. {mul_func});