[XLA] Bitcasts should define values when propagating memory spaces.
Otherwise we can get inconsistent memory propagations. Consider this example: fused_comp { p = s32[1]{0} parameter(0) ... ROOT b = s32[1]{0} bitcast(p) } fusion = s32[1]{0:S(0)} fusion(s32[1]{0:S(1)} foo), fused_computation=fused_comp If bitcast doesn't define a value, then either fusion operand and parameter or fusion root and output would disagree about the memory space. PiperOrigin-RevId: 359147477 Change-Id: Ie785cdf5cc0baeabe4af0f0ec1882592c86d7254
This commit is contained in:
parent
5931dc3cb3
commit
b303673e23
@ -19,8 +19,13 @@ namespace xla {
|
||||
|
||||
StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
|
||||
bool modified = false;
|
||||
// Configure bitcasts to define values. Otherwise, if there is only a bitcast
|
||||
// between a fusion input and output and these two values are in different
|
||||
// memory spaces, we can get inconsistent memory spaces between the parameter
|
||||
// and fusion operand or root and fusion output.
|
||||
TF_ASSIGN_OR_RETURN(auto dataflow_analysis,
|
||||
HloDataflowAnalysis::Run(*module));
|
||||
HloDataflowAnalysis::Run(*module, /*ssa_form=*/false,
|
||||
/*bitcast_defines_value=*/true));
|
||||
dataflow_analysis_ = std::move(dataflow_analysis);
|
||||
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
|
@ -237,7 +237,7 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
|
||||
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
|
||||
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
%fused_computation {
|
||||
@ -248,8 +248,8 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
|
||||
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
|
||||
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
|
||||
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
|
||||
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
@ -310,7 +310,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
|
||||
HloModule NestedFusion
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
%bf_param = s32[6]{0:T(128)} parameter(0)
|
||||
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
@ -322,7 +322,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
|
||||
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
|
||||
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
|
||||
}
|
||||
|
||||
@ -347,5 +347,68 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
TEST_F(MemorySpacePropagationTest, BitcastInFusion) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule TupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
%bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)} %param_0.1)
|
||||
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
)";
|
||||
absl::string_view expected_hlo_string = R"(
|
||||
HloModule TupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
%bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)S(1)} %param_0.1)
|
||||
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
|
||||
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
TF_EXPECT_OK(Verify(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto ref,
|
||||
ParseAndReturnVerifiedModule(expected_hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user