Support X64<->X32 custom calls (X64SplitLow, X64SplitHigh, X64Combine) in HLO replication analysis.

PiperOrigin-RevId: 334882856
Change-Id: Ic2699f2fa4513606300106e965187b9320045d2c
This commit is contained in:
parent
2e69a2dc42
commit
3a63cf6b99
@ -129,6 +129,13 @@ bool DetermineHloInstructionIsReplicated(
|
||||
return true;
|
||||
}
|
||||
|
||||
if (hlo->opcode() == HloOpcode::kCustomCall &&
|
||||
(hlo->custom_call_target() == "X64SplitLow" ||
|
||||
hlo->custom_call_target() == "X64SplitHigh" ||
|
||||
hlo->custom_call_target() == "X64Combine")) {
|
||||
return all_operands_replicated(hlo);
|
||||
}
|
||||
|
||||
if (hlo->IsElementwise() || //
|
||||
hlo->opcode() == HloOpcode::kConcatenate || //
|
||||
hlo->opcode() == HloOpcode::kConvolution || //
|
||||
|
@ -501,6 +501,36 @@ ENTRY entry {
|
||||
FindInstruction(module.get(), "conditional"), {1}));
|
||||
}
|
||||
|
||||
// Checks that replication is propagated through the X64<->X32 custom calls:
// a replicated f64 value is split into f32 halves via "X64SplitLow" and
// "X64SplitHigh" and then recombined via "X64Combine"; every value along that
// chain is expected to be reported as replicated.
TEST_F(HloReplicationAnalysisTest, X64SplitCombine) {
  const string module_str = R"(
HloModule X64SplitCombine

ENTRY entry {
  param = (f64[]) parameter(0)
  gte = f64[] get-tuple-element(param), index=0
  param-low = f32[] custom-call(gte), custom_call_target="X64SplitLow"
  param-high = f32[] custom-call(gte), custom_call_target="X64SplitHigh"
  ROOT result-combine = f64[] custom-call(param-low, param-high), custom_call_target="X64Combine"
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_str));

  // The entry parameter is a one-element tuple, so a single flag marks its
  // only leaf buffer as replicated.
  auto param = module->entry_computation()->parameter_instruction(0);
  param->set_parameter_replicated_at_leaf_buffers(absl::Span<const bool>{true});

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
                          HloReplicationAnalysis::Run(
                              module.get(), /*cross_partition_spmd=*/false));

  // Every instruction derived from the replicated parameter must itself be
  // replicated. The instruction name is streamed into the failure message so
  // a failing iteration can be identified.
  for (const char* instruction_name :
       {"gte", "param-low", "param-high", "result-combine"}) {
    EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
        FindInstruction(module.get(), instruction_name), {}))
        << "instruction: " << instruction_name;
  }
}
|
||||
|
||||
TEST_F(HloReplicationAnalysisTest, SimpleTupleSelect) {
|
||||
const string module_str = R"(
|
||||
HloModule SimpleTupleSelect
|
||||
|
Loading…
Reference in New Issue
Block a user