From b303673e23a10221afe9158efdf4c61998ca3a65 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Tue, 23 Feb 2021 15:23:12 -0800 Subject: [PATCH] [XLA] Bitcasts should define values when propagating memory spaces. Otherwise we can get inconsistent memory propagations. Consider this example: fused_comp { p = s32[1]{0} parameter(0) ... ROOT b = s32[1]{0} bitcast(p) } fusion = s32[1]{0:S(0)} fusion(s32[1]{0:S(1)} foo), fused_computation=fused_comp If bitcast doesn't define a value, then either fusion operand and parameter or fusion root and output would disagree about the memory space. PiperOrigin-RevId: 359147477 Change-Id: Ie785cdf5cc0baeabe4af0f0ec1882592c86d7254 --- .../xla/service/memory_space_propagation.cc | 7 +- .../service/memory_space_propagation_test.cc | 73 +++++++++++++++++-- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc index 2eb15b14eaf..949e4b94e39 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -19,8 +19,13 @@ namespace xla { StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) { bool modified = false; + // Configure bitcasts to define values. Otherwise, if there is only a bitcast + // between a fusion input and output and these two values are in different + // memory spaces, we can get inconsistent memory spaces between the parameter + // and fusion operand or root and fusion output.
TF_ASSIGN_OR_RETURN(auto dataflow_analysis, - HloDataflowAnalysis::Run(*module)); + HloDataflowAnalysis::Run(*module, /*ssa_form=*/false, + /*bitcast_defines_value=*/true)); dataflow_analysis_ = std::move(dataflow_analysis); for (HloComputation* computation : module->MakeNonfusionComputations()) { diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc index de45af5a190..5beaef46387 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc @@ -237,7 +237,7 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) { %bitcast_fusion { %bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0) - ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param) + ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param) } %fused_computation { @@ -248,8 +248,8 @@ TEST_F(MemorySpacePropagationTest, NestedInputFusion) { %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) %param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0) - %fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion - ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1) + %fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1) } ENTRY %entry { @@ -310,7 +310,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) { HloModule NestedFusion %bitcast_fusion { - %bf_param = s32[6]{0:T(128)S(1)} parameter(0) + %bf_param = s32[6]{0:T(128)} parameter(0) ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param) } @@ -322,7 +322,7 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) { %pad.3 = s32[6]{0:T(128)} 
pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) - %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1) ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion } @@ -347,5 +347,68 @@ TEST_F(MemorySpacePropagationTest, NestedOutputFusion) { EXPECT_EQ(module->Hash(), ref->Hash()); } +TEST_F(MemorySpacePropagationTest, BitcastInFusion) { + absl::string_view hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + } + )"; + absl::string_view expected_hlo_string 
= R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + %bitcast.0 = s32[6]{0:T(128)} bitcast(s32[6]{0:T(128)S(1)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%bitcast.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + ROOT %fusion = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + } // namespace } // namespace xla