From 18f2dd4262b3a5a8c3ab7b5f7b475aeabcfc8a6b Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Mon, 8 Feb 2021 16:08:37 -0800 Subject: [PATCH] Allow more sharding propagation from broadcast operands MaybeImproveInstructionSharding() already checks that it needs to be compatible with existing tiled sharding (if any). PiperOrigin-RevId: 356375771 Change-Id: Ic5e65e00cea0a25cec2cf93027632c4c7783a20e --- .../xla/service/sharding_propagation.cc | 5 --- .../xla/service/sharding_propagation_test.cc | 31 +++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index c7ece50d202..49caa2ee347 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -739,11 +739,6 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (aggressiveness < 3) { return false; } - // Do not override existing tile sharding. This is likely from users. - if (IsSpatiallyPartitioned(instruction) && - !instruction->sharding().IsTileMaximal()) { - return false; - } const HloInstruction* op = instruction->operand(0); if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { return false; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index eed6de9b5b7..f91ecec8b9f 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -392,6 +392,37 @@ ENTRY %broadcast { } } +TEST_P(ParameterizedMetadataTest, BroadcastMerge) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048]parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}} + %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1} + ROOT %copy = f32[3,2048,3] copy(%broadcast), + sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + EXPECT_TRUE(changed); + auto* instruction = FindInstruction(module.get(), "broadcast"); + ASSERT_NE(instruction, nullptr); + EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}")); + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(instruction->sharding(), + ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + } else { + EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); + } +} + TEST_P(ParameterizedMetadataTest, BroadcastUser) { const char* const hlo_string = R"( HloModule module