Add HloTestBase::Run function that takes HloModule.

PiperOrigin-RevId: 303146442
Change-Id: I2ed043b5c65442e05b5c36e344161f18623734d6
This commit is contained in:
A. Unique TensorFlower 2020-03-26 10:43:26 -07:00 committed by TensorFlower Gardener
parent c3e179cc7e
commit 568f90e6c1
2 changed files with 21 additions and 0 deletions

View File

@ -312,6 +312,22 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
reference_preprocessor);
}
::testing::AssertionResult HloTestBase::Run(std::unique_ptr<HloModule> module,
bool run_hlo_passes) {
const auto fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
const auto change = hlo_verifier_->Run(module.get());
if (!change.ok()) {
return ::testing::AssertionFailure() << change.status();
}
const auto output =
test_runner_.Execute(std::move(module), fake_arguments, run_hlo_passes);
return output.ok()
? ::testing::AssertionSuccess()
: ::testing::AssertionFailure() << output.status().error_message();
}
::testing::AssertionResult HloTestBase::RunAndCompare(
string_view hlo_string, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {

View File

@ -203,6 +203,11 @@ class HloTestBase : public ::testing::Test {
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Executes an hlo module with fake inputs and checks that the execution is
// successful.
::testing::AssertionResult Run(std::unique_ptr<HloModule> module,
bool run_hlo_passes) TF_MUST_USE_RESULT;
// Convenient wrappers for executing and comparing an hlo module with fake
// input. Module can be passed in directly, or parsed from an hlo_string,
// or loaded from a file.