[Grappler] Fix bug in control overrides (function optimizer)

PiperOrigin-RevId: 224075171
This commit is contained in:
Eugene Zhulenev 2018-12-04 17:27:09 -08:00 committed by TensorFlower Gardener
parent 4efd674dab
commit 513f07b954
2 changed files with 9 additions and 5 deletions

View File

@ -1665,7 +1665,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
gtl::FlatSet<string> add_ctrl_inputs;
// Remove all invalidated control inputs.
for (int idx = 0; idx < node.input_size(); ++idx) {
for (int idx = 0; idx < node.input_size(); /* see below */) {
// TODO(ezhulenev): Use non-allocating TensorId after migrating
// `control_overrides()` to absl::flat_hash_set.
SafeTensorId input_tensor = ParseTensorName(node.input(idx));
@ -1685,6 +1685,10 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
for (const string& override : overrides->second) {
add_ctrl_inputs.insert(AsControlDependency(override));
}
} else {
// Go to the next input only if the current one was not invalidated,
// otherwise we need to check the swapped input as well.
++idx;
}
}

View File

@ -812,8 +812,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
// Return result of multiplication and a current value of the variable.
NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice),
NDef("out_2", "ReadVariableOp", {"v", "^f2"}, {{"dtype", DT_FLOAT}},
kDevice)},
NDef("out_2", "ReadVariableOp", {"v", "^f1", "^f2"},
{{"dtype", DT_FLOAT}}, kDevice)},
// Function library.
{mul_func});
@ -860,8 +860,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
// Return values read directly from inlined nodes.
NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice),
NDef("out_2", "ReadVariableOp", {"v", "^f2/add"}, {{"dtype", DT_FLOAT}},
kDevice)},
NDef("out_2", "ReadVariableOp", {"v", "^f1/add", "^f2/add"},
{{"dtype", DT_FLOAT}}, kDevice)},
// Function library.
{mul_func});