[XLA] Bitcasts should define values when propagating memory spaces.

Otherwise we can get inconsistent memory propagations. Consider this example:

fused_comp {
  p = s32[1]{0} parameter(0)
  ...
  ROOT b = s32[1]{0} bitcast(p)
}

fusion = s32[1]{0:S(0)} fusion(s32[1]{0:S(1)} foo), fused_computation=fused_comp

If bitcast doesn't define a value, then either fusion operand and parameter or
fusion root and output would disagree about the memory space.

PiperOrigin-RevId: 359147477
Change-Id: Ie785cdf5cc0baeabe4af0f0ec1882592c86d7254
This commit is contained in:
Berkin Ilbeyi 2021-02-23 15:23:12 -08:00 committed by TensorFlower Gardener
parent 5931dc3cb3
commit b303673e23
2 changed files with 74 additions and 6 deletions

View File

@ -19,8 +19,13 @@ namespace xla {
StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
bool modified = false;
// Configure bitcasts to define values. Otherwise, if there is only a bitcast
// between a fusion input and output and these two values are in different
// memory spaces, we can get inconsistent memory spaces between the parameter
// and fusion operand or root and fusion output.
TF_ASSIGN_OR_RETURN(auto dataflow_analysis,
HloDataflowAnalysis::Run(*module));
HloDataflowAnalysis::Run(*module, /*ssa_form=*/false,
/*bitcast_defines_value=*/true));
dataflow_analysis_ = std::move(dataflow_analysis);
for (HloComputation* computation : module->MakeNonfusionComputations()) {

View File

@ -237,7 +237,7 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
}
%fused_computation {
@ -248,8 +248,8 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
}
ENTRY %entry {
@ -310,7 +310,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
%bf_param = s32[6]{0:T(128)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
}
@ -322,7 +322,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
@ -347,5 +347,68 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, BitcastInFusion) {
absl::string_view hlo_string = R"(
HloModule TupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
%bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)} %param_0.1)
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
}
)";
absl::string_view expected_hlo_string = R"(
HloModule TupleOutput
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)S(1)} %param_0.1)
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0)
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
} // namespace
} // namespace xla