Added a RunReplicated() method to HloTestBase.
PiperOrigin-RevId: 333359479 Change-Id: Ied332dd75d734bd7017369e99820205bf5ec6842
This commit is contained in:
parent
b65ac36a29
commit
588c8db0c2
@ -414,6 +414,47 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
|||||||
: ::testing::AssertionFailure() << output.status().error_message();
|
: ::testing::AssertionFailure() << output.status().error_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
::testing::AssertionResult HloTestBase::RunReplicated(string_view hlo_string,
|
||||||
|
bool run_hlo_passes,
|
||||||
|
int64 num_replicas,
|
||||||
|
string backend_config) {
|
||||||
|
auto module_or_status =
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string, num_replicas);
|
||||||
|
if (!module_or_status.ok()) {
|
||||||
|
return ::testing::AssertionFailure()
|
||||||
|
<< "Error while parsing HLO text format: "
|
||||||
|
<< module_or_status.status().ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<HloModule> module = std::move(module_or_status.ValueOrDie());
|
||||||
|
const auto& fake_arguments =
|
||||||
|
MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||||
|
std::vector<Literal*> fake_argument_ptrs;
|
||||||
|
absl::c_transform(
|
||||||
|
fake_arguments, std::back_inserter(fake_argument_ptrs),
|
||||||
|
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
|
||||||
|
|
||||||
|
if (!backend_config.empty()) {
|
||||||
|
// Set backend configuration if it is given.
|
||||||
|
HloInstruction* instruction =
|
||||||
|
module->entry_computation()->root_instruction();
|
||||||
|
instruction->set_raw_backend_config_string(backend_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
HloRunner::ReplicatedExecuteOptions options;
|
||||||
|
options.num_replicas = num_replicas;
|
||||||
|
options.run_hlo_passes = run_hlo_passes;
|
||||||
|
options.use_threads = true;
|
||||||
|
for (auto argument : fake_argument_ptrs) {
|
||||||
|
options.arguments.push_back(argument);
|
||||||
|
}
|
||||||
|
auto output = test_runner_.ExecuteReplicated(std::move(module), options);
|
||||||
|
|
||||||
|
return output.ok()
|
||||||
|
? ::testing::AssertionSuccess()
|
||||||
|
: ::testing::AssertionFailure() << output.status().error_message();
|
||||||
|
}
|
||||||
|
|
||||||
::testing::AssertionResult HloTestBase::RunMultipleTimes(
|
::testing::AssertionResult HloTestBase::RunMultipleTimes(
|
||||||
string_view hlo_string, bool run_hlo_passes,
|
string_view hlo_string, bool run_hlo_passes,
|
||||||
std::vector<ExecutionProfile>* profiles, string backend_config,
|
std::vector<ExecutionProfile>* profiles, string backend_config,
|
||||||
|
@ -234,6 +234,11 @@ class HloTestBase : public ManifestCheckingTest {
|
|||||||
ExecutionProfile* profile = nullptr,
|
ExecutionProfile* profile = nullptr,
|
||||||
string backend_config = "") TF_MUST_USE_RESULT;
|
string backend_config = "") TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
|
// Executes an hlo module with fake inputs on multiple replicas.
|
||||||
|
::testing::AssertionResult RunReplicated(
|
||||||
|
const absl::string_view hlo_string, bool run_hlo_passes = true,
|
||||||
|
int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT;
|
||||||
|
|
||||||
// If assert_determinism is true, the assertion will fail unless all runs
|
// If assert_determinism is true, the assertion will fail unless all runs
|
||||||
// produce exactly the same output.
|
// produce exactly the same output.
|
||||||
::testing::AssertionResult RunMultipleTimes(
|
::testing::AssertionResult RunMultipleTimes(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user