Added a RunReplicated() method to HloTestBase.

PiperOrigin-RevId: 333359479
Change-Id: Ied332dd75d734bd7017369e99820205bf5ec6842
This commit is contained in:
Gaurav Agrawal 2020-09-23 13:21:26 -07:00 committed by TensorFlower Gardener
parent b65ac36a29
commit 588c8db0c2
2 changed files with 46 additions and 0 deletions

View File

@ -414,6 +414,47 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
: ::testing::AssertionFailure() << output.status().error_message();
}
::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string,
bool run_hlo_passes,
int64 num_replicas,
string backend_config) {
auto module_or_status =
ParseAndReturnVerifiedModule(hlo_string, num_replicas);
if (!module_or_status.ok()) {
return ::testing::AssertionFailure()
<< "Error while parsing HLO text format: "
<< module_or_status.status().ToString();
}
std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
const auto& fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
if (!backend_config.empty()) {
// Set backend configuration if it is given.
HloInstruction* instruction =
module->entry_computation()->root_instruction();
instruction->set_raw_backend_config_string(backend_config);
}
HloRunner::ReplicatedExecuteOptions options;
options.num_replicas = num_replicas;
options.run_hlo_passes = run_hlo_passes;
options.use_threads = true;
for (auto argument : fake_argument_ptrs) {
options.arguments.push_back(argument);
}
auto output = test_runner_.ExecuteReplicated(std::move(module), options);
return output.ok()
? ::testing::AssertionSuccess()
: ::testing::AssertionFailure() << output.status().error_message();
}
::testing::AssertionResult HloTestBase::RunMultipleTimes(
string_view hlo_string, bool run_hlo_passes,
std::vector<ExecutionProfile>* profiles, string backend_config,

View File

@ -234,6 +234,11 @@ class HloTestBase : public ManifestCheckingTest {
ExecutionProfile* profile = nullptr,
string backend_config = "") TF_MUST_USE_RESULT;
// Executes an hlo module with fake inputs on multiple replicas.
::testing::AssertionResult RunReplicated(
const absl::string_view hlo_string, bool run_hlo_passes = true,
int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT;
// If assert_determinism is true, the assertion will fail unless all runs
// produce exactly the same output.
::testing::AssertionResult RunMultipleTimes(