[Grappler] Fix bug in control overrides (function optimizer)
PiperOrigin-RevId: 224075171
This commit is contained in:
parent
4efd674dab
commit
513f07b954
tensorflow/core/grappler/optimizers
@ -1665,7 +1665,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
|
|||||||
gtl::FlatSet<string> add_ctrl_inputs;
|
gtl::FlatSet<string> add_ctrl_inputs;
|
||||||
|
|
||||||
// Remove all invalidated control inputs.
|
// Remove all invalidated control inputs.
|
||||||
for (int idx = 0; idx < node.input_size(); ++idx) {
|
for (int idx = 0; idx < node.input_size(); /* see below */) {
|
||||||
// TODO(ezhulenev): Use non-allocating TensorId after migrating
|
// TODO(ezhulenev): Use non-allocating TensorId after migrating
|
||||||
// `control_overrides()` to absl::flat_hash_set.
|
// `control_overrides()` to absl::flat_hash_set.
|
||||||
SafeTensorId input_tensor = ParseTensorName(node.input(idx));
|
SafeTensorId input_tensor = ParseTensorName(node.input(idx));
|
||||||
@ -1685,6 +1685,10 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
|
|||||||
for (const string& override : overrides->second) {
|
for (const string& override : overrides->second) {
|
||||||
add_ctrl_inputs.insert(AsControlDependency(override));
|
add_ctrl_inputs.insert(AsControlDependency(override));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Go to the next input only if the current one was not invalidated,
|
||||||
|
// otherwise we need to check the swapped input as well.
|
||||||
|
++idx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -812,8 +812,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
|||||||
|
|
||||||
// Return result of multiplication and a current value of the variable.
|
// Return result of multiplication and a current value of the variable.
|
||||||
NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice),
|
NDef("out_1", "Identity", {"f2"}, {{"T", DT_FLOAT}}, kDevice),
|
||||||
NDef("out_2", "ReadVariableOp", {"v", "^f2"}, {{"dtype", DT_FLOAT}},
|
NDef("out_2", "ReadVariableOp", {"v", "^f1", "^f2"},
|
||||||
kDevice)},
|
{{"dtype", DT_FLOAT}}, kDevice)},
|
||||||
|
|
||||||
// Function library.
|
// Function library.
|
||||||
{mul_func});
|
{mul_func});
|
||||||
@ -860,8 +860,8 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
|||||||
|
|
||||||
// Return values read directly from inlined nodes.
|
// Return values read directly from inlined nodes.
|
||||||
NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice),
|
NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice),
|
||||||
NDef("out_2", "ReadVariableOp", {"v", "^f2/add"}, {{"dtype", DT_FLOAT}},
|
NDef("out_2", "ReadVariableOp", {"v", "^f1/add", "^f2/add"},
|
||||||
kDevice)},
|
{{"dtype", DT_FLOAT}}, kDevice)},
|
||||||
|
|
||||||
// Function library.
|
// Function library.
|
||||||
{mul_func});
|
{mul_func});
|
||||||
|
Loading…
Reference in New Issue
Block a user