diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 79974723b8b..64f5440b99f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -212,9 +212,10 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, StatusOr> HloTestBase::ExecuteReplicated( std::unique_ptr module, absl::Span arguments, - int64 num_replicas, bool use_threads) { + int64 num_replicas, bool use_threads, bool run_hlo_passes) { HloRunner::ReplicatedExecuteOptions options; options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; options.use_threads = use_threads; for (auto argument : arguments) { options.arguments.push_back(argument); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7a78307a467..d4a1788c928 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -181,7 +181,7 @@ class HloTestBase : public ::testing::Test { // Executable::ExecuteOnStreams. StatusOr> ExecuteReplicated( std::unique_ptr module, absl::Span arguments, - int64 num_replicas, bool use_threads); + int64 num_replicas, bool use_threads, bool run_hlo_passes = false); // Same as above, but uses specified device assignment. StatusOr> ExecuteReplicated(