diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 2f8ce62dd84..1316e8ad1aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
 
+#include <algorithm>
 #include <iterator>
 #include <stack>
 #include <vector>
@@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
   return false;
 }
 
+std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
+  std::vector<int64> relative_order;
+  for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
+    if (shape.dimensions(dim) > 1) {
+      relative_order.push_back(dim);
+    }
+  }
+  // Now normalize the dimensions to values between 0 and true rank - 1.
+  std::vector<int64> sorted_dims = relative_order;
+  std::sort(sorted_dims.begin(), sorted_dims.end());
+  for (int64& dim : relative_order) {
+    int64 sorted_index = std::distance(
+        sorted_dims.begin(),
+        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
+    dim = sorted_index;
+  }
+  return relative_order;
+}
+
 }  // namespace
 
 bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
@@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
   std::vector<HloInstruction*> params;
   AppendParams(producer, &params);
   AppendParams(reduce, &params);
-  int64 max_rank = -1;
-  const Layout* max_rank_layout;
+  int64 max_true_rank = -1;
+  std::vector<int64> max_rank_order;
   for (HloInstruction* param : params) {
-    if (param->shape().IsArray() && param->shape().rank() > max_rank) {
-      max_rank = param->shape().rank();
-      max_rank_layout = &param->shape().layout();
+    if (param->shape().IsArray() &&
+        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
+      max_true_rank = ShapeUtil::TrueRank(param->shape());
+      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
     }
   }
   return absl::c_all_of(params, [&](HloInstruction* param) {
-    return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
-           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+    return !param->shape().IsArray() ||
+           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
+           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
+               max_rank_order;
   });
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
index ae31b10deb3..854aab86b8e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest,
       LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
 }
 
+TEST_F(GpuFusibleTest,
+       LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    mixed_input_layouts_computation {
+      p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
+      copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1)
+      c0 = f16[] constant(0)
+      broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+      greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
+      ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+    }
+    fused_reduce {
+      p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2)
+      c0.2 = f32[] constant(0)
+      ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+    }
+    ENTRY entry {
+      p0 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      p1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
+      loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+      reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+      ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+    })"))
+                    .ValueOrDie();
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* reduce_fusion =
+      module->entry_computation()->root_instruction()->operand(0);
+  ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+            HloOpcode::kReduce);
+  const HloInstruction* loop_fusion =
+      module->entry_computation()->root_instruction()->operand(1);
+  ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
+  EXPECT_TRUE(
+      LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
 TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
   auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
     fused_reduce {
@@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest,
 }
 
 TEST_F(GpuFusibleTest,
-       LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
+       LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) {
   auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
     broadcasting_computation {
       p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
-      p1.1 = f32[128]{0} parameter(1)
-      broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
+      p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
+      reshape = f32[128]{0} reshape(p1.1)
+      broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0}
       ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
     }
     ENTRY entry {
       p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
-      p1 = f32[128]{0} parameter(1)
+      p1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
       loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
       c0.2 = f32[] constant(0)
       ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
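
For illustration only (not part of the patch): a minimal standalone sketch of what ExtractRelativeOrderOfNontrivialDims computes. Here int64_t stands in for XLA's int64, and the shape's dimension sizes and minor-to-major layout are passed as plain vectors instead of being read from a Shape.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Keep only dimensions of size > 1, in minor-to-major order, then renumber
// them to 0..true_rank-1 while preserving their relative order.
std::vector<int64_t> RelativeOrderOfNontrivialDims(
    const std::vector<int64_t>& dimensions,
    const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> relative_order;
  for (int64_t dim : minor_to_major) {
    if (dimensions[dim] > 1) relative_order.push_back(dim);
  }
  std::vector<int64_t> sorted_dims = relative_order;
  std::sort(sorted_dims.begin(), sorted_dims.end());
  for (int64_t& dim : relative_order) {
    dim = std::distance(
        sorted_dims.begin(),
        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
  }
  return relative_order;
}

int main() {
  // f16[128,1,32,32]{1,3,2,0}: the size-1 dimension 1 is dropped, leaving
  // {3,2,0}, which renumbers to {2,1,0}.
  for (int64_t d :
       RelativeOrderOfNontrivialDims({128, 1, 32, 32}, {1, 3, 2, 0})) {
    std::cout << d << ' ';  // prints: 2 1 0
  }
  std::cout << '\n';
  // f16[128,1,32,32]{3,2,1,0} also yields {2,1,0}, so the two layouts are
  // treated as equivalent by the new LayoutsAreReduceInputFusionFriendly.
  for (int64_t d :
       RelativeOrderOfNontrivialDims({128, 1, 32, 32}, {3, 2, 1, 0})) {
    std::cout << d << ' ';  // prints: 2 1 0
  }
  std::cout << '\n';
}

This is why the MixedLayoutProducerWithTrivialDim test expects fusion-friendliness: the two parameter layouts differ only in where the trivial dimension sits, so their relative orders of nontrivial dimensions coincide.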