Support x64<->x32 instructions in hlo replication analysis.

PiperOrigin-RevId: 334882856
Change-Id: Ic2699f2fa4513606300106e965187b9320045d2c
This commit is contained in:
Yunxing Dai 2020-10-01 13:16:27 -07:00 committed by TensorFlower Gardener
parent 2e69a2dc42
commit 3a63cf6b99
2 changed files with 37 additions and 0 deletions

View File

@ -129,6 +129,13 @@ bool DetermineHloInstructionIsReplicated(
return true;
}
if (hlo->opcode() == HloOpcode::kCustomCall &&
(hlo->custom_call_target() == "X64SplitLow" ||
hlo->custom_call_target() == "X64SplitHigh" ||
hlo->custom_call_target() == "X64Combine")) {
return all_operands_replicated(hlo);
}
if (hlo->IsElementwise() || //
hlo->opcode() == HloOpcode::kConcatenate || //
hlo->opcode() == HloOpcode::kConvolution || //

View File

@ -501,6 +501,36 @@ ENTRY entry {
FindInstruction(module.get(), "conditional"), {1}));
}
TEST_F(HloReplicationAnalysisTest, X64SplitCombine) {
const string module_str = R"(
HloModule SimpleTupleSelect
ENTRY entry {
param = (f64[]) parameter(0)
gte = f64[] get-tuple-element(param), index=0
param-low = f32[] custom-call(gte), custom_call_target="X64SplitLow"
param-high = f32[] custom-call(gte), custom_call_target="X64SplitHigh"
ROOT result-combine = f64[] custom-call(param-low, param-high), custom_call_target="X64Combine"
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_str));
auto param = module->entry_computation()->parameter_instruction(0);
param->set_parameter_replicated_at_leaf_buffers(absl::Span<const bool>{true});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloReplicationAnalysis> analysis,
HloReplicationAnalysis::Run(
module.get(), /*cross_partition_spmd=*/false));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "gte"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "param-low"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "param-high"), {}));
EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt(
FindInstruction(module.get(), "result-combine"), {}));
}
TEST_F(HloReplicationAnalysisTest, SimpleTupleSelect) {
const string module_str = R"(
HloModule SimpleTupleSelect