From 568f90e6c1b41f4f57dbb574cd998cdc0ca57734 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 26 Mar 2020 10:43:26 -0700 Subject: [PATCH] Add HloTestBase::Run function that takes HloModule. PiperOrigin-RevId: 303146442 Change-Id: I2ed043b5c65442e05b5c36e344161f18623734d6 --- tensorflow/compiler/xla/tests/hlo_test_base.cc | 16 ++++++++++++++++ tensorflow/compiler/xla/tests/hlo_test_base.h | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 1a1dda80f18..64d586a9514 100755 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -312,6 +312,22 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(std::unique_ptr module, + bool run_hlo_passes) { + const auto fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + const auto change = hlo_verifier_->Run(module.get()); + if (!change.ok()) { + return ::testing::AssertionFailure() << change.status(); + } + + const auto output = + test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes); + return output.ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << output.status().error_message(); +} + ::testing::AssertionResult HloTestBase::RunAndCompare( string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index eebe26ecde5..0b1801ebe23 100755 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -203,6 +203,11 @@ class HloTestBase : public ::testing::Test { const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + // Executes an hlo module with fake inputs and checks that the execution is + // successful. + ::testing::AssertionResult Run(std::unique_ptr module, + bool run_hlo_passes) TF_MUST_USE_RESULT; + // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file.