Added a RunReplicated() method to HloTestBase.
PiperOrigin-RevId: 333359479 Change-Id: Ied332dd75d734bd7017369e99820205bf5ec6842
This commit is contained in:
parent
b65ac36a29
commit
588c8db0c2
@ -414,6 +414,47 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||
: ::testing::AssertionFailure() << output.status().error_message();
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string,
|
||||
bool run_hlo_passes,
|
||||
int64 num_replicas,
|
||||
string backend_config) {
|
||||
auto module_or_status =
|
||||
ParseAndReturnVerifiedModule(hlo_string, num_replicas);
|
||||
if (!module_or_status.ok()) {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "Error while parsing HLO text format: "
|
||||
<< module_or_status.status().ToString();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
|
||||
const auto& fake_arguments =
|
||||
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||
std::vector<Literal*> fake_argument_ptrs;
|
||||
absl::c_transform(
|
||||
fake_arguments, std::back_inserter(fake_argument_ptrs),
|
||||
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
|
||||
|
||||
if (!backend_config.empty()) {
|
||||
// Set backend configuration if it is given.
|
||||
HloInstruction* instruction =
|
||||
module->entry_computation()->root_instruction();
|
||||
instruction->set_raw_backend_config_string(backend_config);
|
||||
}
|
||||
|
||||
HloRunner::ReplicatedExecuteOptions options;
|
||||
options.num_replicas = num_replicas;
|
||||
options.run_hlo_passes = run_hlo_passes;
|
||||
options.use_threads = true;
|
||||
for (auto argument : fake_argument_ptrs) {
|
||||
options.arguments.push_back(argument);
|
||||
}
|
||||
auto output = test_runner_.ExecuteReplicated(std::move(module), options);
|
||||
|
||||
return output.ok()
|
||||
? ::testing::AssertionSuccess()
|
||||
: ::testing::AssertionFailure() << output.status().error_message();
|
||||
}
|
||||
|
||||
::testing::AssertionResult HloTestBase::RunMultipleTimes(
|
||||
string_view hlo_string, bool run_hlo_passes,
|
||||
std::vector<ExecutionProfile>* profiles, string backend_config,
|
||||
|
@ -234,6 +234,11 @@ class HloTestBase : public ManifestCheckingTest {
|
||||
ExecutionProfile* profile = nullptr,
|
||||
string backend_config = "") TF_MUST_USE_RESULT;
|
||||
|
||||
// Executes an hlo module with fake inputs on multiple replicas.
|
||||
::testing::AssertionResult RunReplicated(
|
||||
const absl::string_view hlo_string, bool run_hlo_passes = true,
|
||||
int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT;
|
||||
|
||||
// If assert_determinism is true, the assertion will fail unless all runs
|
||||
// produce exactly the same output.
|
||||
::testing::AssertionResult RunMultipleTimes(
|
||||
|
Loading…
x
Reference in New Issue
Block a user