diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h index b01f6b57e64..ec26dfec1d2 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -63,7 +63,7 @@ class MlirOptimizationPassRegistry { } }; - using Passes = std::set; + using Passes = std::multiset; // Returns the global registry of MLIR optimization passes. static MlirOptimizationPassRegistry& Global(); @@ -145,7 +145,7 @@ class MlirV1CompatOptimizationPassRegistry { } }; - using Passes = std::set; + using Passes = std::multiset; // Returns the global registry of MLIR optimization passes. static MlirV1CompatOptimizationPassRegistry& Global(); diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 74992f67532..371719cb319 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -118,4 +118,13 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) { #endif } +TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriority) { + MlirOptimizationPassRegistry::Global().Add( + 0, std::make_unique>()); + MlirOptimizationPassRegistry::Global().Add( + 0, std::make_unique>()); + + EXPECT_EQ(MlirOptimizationPassRegistry::Global().passes().size(), 2); +} + } // namespace tensorflow