diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index 0358d7f5409..d3ad43728f2 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -68,37 +68,6 @@ class ArithmeticOptimizerTest : public GrapplerTest { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); } - // TODO(ezhulenev): Make private. After migration to stages each test - // should explicitly enable required optimization for tests isolation - void DisableAllStages(ArithmeticOptimizer* optimizer) { - ArithmeticOptimizer::ArithmeticOptimizerOptions options; - options.dedup_computations = false; - options.combine_add_to_addn = false; - options.convert_sqrt_div_to_rsqrt_mul = false; - options.convert_pow = false; - options.convert_log1p = false; - options.optimize_max_or_min_of_monotonic = false; - options.fold_conjugate_into_transpose = false; - options.fold_multiply_into_conv = false; - options.fold_transpose_into_matmul = false; - options.hoist_common_factor_out_of_aggregation = false; - options.hoist_cwise_unary_chains = false; - options.minimize_broadcasts = false; - options.remove_identity_transpose = false; - options.remove_involution = false; - options.remove_idempotent = false; - options.remove_redundant_bitcast = false; - options.remove_redundant_cast = false; - options.remove_redundant_reshape = false; - options.remove_negation = false; - options.remove_logical_not = false; - options.reorder_cast_like_and_value_preserving = false; - options.replace_mul_with_square = false; - options.simplify_aggregation = false; - options.unary_ops_composition = false; - optimizer->options_ = options; - } - void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { optimizer->options_.combine_add_to_addn = false; } @@ -238,6 +207,36 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_stack_strided_slice_same_axis = true; } + + private: + void DisableAllStages(ArithmeticOptimizer* optimizer) { + ArithmeticOptimizer::ArithmeticOptimizerOptions options; + options.dedup_computations = false; + options.combine_add_to_addn = false; + options.convert_sqrt_div_to_rsqrt_mul = false; + options.convert_pow = false; + options.convert_log1p = false; + options.optimize_max_or_min_of_monotonic = false; + options.fold_conjugate_into_transpose = false; + options.fold_multiply_into_conv = false; + options.fold_transpose_into_matmul = false; + options.hoist_common_factor_out_of_aggregation = false; + options.hoist_cwise_unary_chains = false; + options.minimize_broadcasts = false; + options.remove_identity_transpose = false; + options.remove_involution = false; + options.remove_idempotent = false; + options.remove_redundant_bitcast = false; + options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; + options.remove_negation = false; + options.remove_logical_not = false; + options.reorder_cast_like_and_value_preserving = false; + options.replace_mul_with_square = false; + options.simplify_aggregation = false; + options.unary_ops_composition = false; + optimizer->options_ = options; + } }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index f48f5b01a79..de5257e3cef 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -60,13 +60,6 @@ class LoopOptimizerTest : public GrapplerTest { AddNode(name, op, inputs, attributes, graph); } - void DisableAllStages(LoopOptimizer* optimizer) { - LoopOptimizer::LoopOptimizerOptions options; - options.enable_loop_invariant_node_motion = false; - options.enable_stack_push_removal = false; - optimizer->options_ = options; - } - void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.enable_loop_invariant_node_motion = true; @@ -76,6 +69,14 @@ class LoopOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.enable_stack_push_removal = true; } + + private: + void DisableAllStages(LoopOptimizer* optimizer) { + LoopOptimizer::LoopOptimizerOptions options; + options.enable_loop_invariant_node_motion = false; + options.enable_stack_push_removal = false; + optimizer->options_ = options; + } }; TEST_F(LoopOptimizerTest, Basic) {