From 76e256261ba4feee18aeb9a2e69d3d5c861ad976 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Fri, 14 Jun 2019 04:26:39 -0700
Subject: [PATCH] [XLA:CPU] Strip one-dimensions for dot by default

This switches the canonical dot representation away from having explicit
one-dimensions, allowing further optimization and bringing CPU in line with
the other backends. The dot->reduce simplification is still gated on strength
reduction being enabled, as we don't have a fast implementation of the reduce
form on CPU yet.

Sadly, the assumption that vector-matrix multiplication is matrix-matrix
multiplication with extra 1-dimensions is baked into many places. I think I
managed to fix them all and adjusted the tests.

PiperOrigin-RevId: 253205266
---
 .../xla/service/algebraic_simplifier.cc       | 23 ++++++------
 .../xla/service/cpu/cpu_instruction_fusion.cc | 23 ++++++------
 .../cpu/cpu_instruction_fusion_test.cc        | 27 ++++++++------
 .../service/cpu/cpu_layout_assignment_test.cc | 32 +++++++++++------
 .../xla/service/cpu/dot_op_emitter.cc         | 35 +++++++++++++------
 .../compiler/xla/service/hlo_instruction.cc   | 23 ++++++------
 .../compiler/xla/service/hlo_instruction.h    |  4 +--
 .../compiler/xla/tests/dot_operation_test.cc  |  5 +++
 tensorflow/compiler/xla/tests/test_utils.cc   |  7 ++--
 9 files changed, 106 insertions(+), 73 deletions(-)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 9d3707656c9..697c9313073 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1783,12 +1783,13 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
 
   // If the lhs or rhs have only batch and contracting dimensions, a dot can be
   // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
-  if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
-           dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
-       lhs->shape().rank()) ||
-      (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
-           dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
-       rhs->shape().rank())) {
+  if (options_.enable_dot_strength_reduction() &&
+      ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
+            dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
+        lhs->shape().rank()) ||
+       (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
+            dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
+        rhs->shape().rank()))) {
     TF_ASSIGN_OR_RETURN(
         HloInstruction * new_lhs,
         NormalizeDotOperandToBatchMajorAndContractingMinor(
@@ -1884,12 +1885,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
     return ReplaceInstruction(dot, dot_of_gather_optimized);
   }
 
-  if (options_.enable_dot_strength_reduction()) {
-    TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
-                        RemoveDegenerateDimensionFromDot(dot));
-    if (removed_degenerate_dimensions) {
-      return Status::OK();
-    }
+  TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
+                      RemoveDegenerateDimensionFromDot(dot));
+  if (removed_degenerate_dimensions) {
+    return Status::OK();
   }
 
   // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
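
To make the new canonical form concrete, here is a minimal sketch (not part of
the patch) of a vector-matrix dot built with the XLA C++ builder API that the
tests below already use; the enclosing function, headers, and xla namespace
are assumed, and the builder name is made up for illustration.

  // Sketch: f32[256] dot f32[256,1024] -> f32[1024], with no one-dimensions.
  // The rank-1 LHS contracts its only dimension (0) against RHS dimension 0.
  HloComputation::Builder builder("vector_matrix_dot_sketch");
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {256}), "lhs"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256, 1024}), "rhs"));
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(F32, {1024}), lhs, rhs, dnums, precision_config));

Before this change the same product had to be written as
f32[1,256] dot f32[256,1024] -> f32[1,1024]; that trailing/leading 1 is the
degenerate dimension the simplifier now strips via
RemoveDegenerateDimensionFromDot regardless of the strength-reduction flag.
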
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 3d7c06cfa27..a37b5e9e51b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -22,11 +22,6 @@ namespace cpu {
 
 namespace {
 
-int64 BytesInDimension(const Shape& shape, int64 dimension) {
-  return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
-         shape.dimensions(dimension);
-}
-
 bool CanBeLoopFused(const HloInstruction& hlo) {
   // These are the only ones we fuse since we rely on effective elemental IR
   // generation.
@@ -47,8 +42,7 @@
 bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
   const Shape& hlo_shape = hlo->shape();
   return !ShapeUtil::ElementIsComplex(hlo_shape) &&
-         hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 &&
-         (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
+         hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1;
 }
 
 bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
@@ -136,18 +130,21 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     // fusion can easily be overshadowed by the overhead of a naive GEMM
     // algorithm in the IR.
     const Shape& output_shape = consumer->shape();
-    if (output_shape.dimensions_size() == 2) {
-      // We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B]) and
+    if (output_shape.dimensions_size() <= 1) {
+      // We fuse in cases where we have a matrix*vector or vector*matrix dot and
       // fusion can get rid of the larger tensor. We assume that a naive
       // traversal of a small enough (to fit in L1) column or row tensor is
      // "good enough" from the perspective of cache management; and calling out
       // to an optimized GEMM kernel is not a huge win.
-      if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
-          BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
+      if (consumer->operand(0)->shape().rank() == 1 && operand_index == 1 &&
+          ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) <
+              kFusionThresholdBytes) {
         VLOG(2) << "Fusing small matrix-vector product.";
         return true;
-      } else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
-                 BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
+      } else if (consumer->operand(1)->shape().rank() == 1 &&
+                 operand_index == 0 &&
+                 ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) <
+                     kFusionThresholdBytes) {
         VLOG(2) << "Fusing small matrix-vector product.";
         return true;
       }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index b35026a41cd..e1f5ab11456 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -51,12 +51,12 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
-      1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
+      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
   HloInstruction* dot = builder.AddInstruction(
-      MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
+      MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1));
 
   auto module = CreateNewUnverifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
@@ -68,14 +68,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
   HloComputation::Builder builder(TestName());
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
+      0, ShapeUtil::MakeShape(F32, {256}), "arg0"));
   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
       1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
   HloInstruction* dot = builder.AddInstruction(
-      MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
+      MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1));
 
   auto module = CreateNewUnverifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
@@ -89,14 +89,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
-      1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
+      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
   HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
   HloInstruction* dot = builder.AddInstruction(
-      MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
+      MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1));
 
   auto module = CreateNewUnverifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
@@ -110,7 +110,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
-      1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
+      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
 
@@ -118,7 +118,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
       builder.AddInstruction(HloInstruction::CreateReshape(
           ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
   HloInstruction* dot = builder.AddInstruction(
-      MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
+      MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1));
 
   auto module = CreateNewUnverifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
@@ -130,14 +130,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
   HloComputation::Builder builder(TestName());
   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
-      0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
+      0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0"));
   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
       1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
       ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
   HloInstruction* dot = builder.AddInstruction(
-      MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
+      MakeDot(ShapeUtil::MakeShape(F32, {32 * 1024}), arg0, exp1));
 
   auto module = CreateNewUnverifiedModule();
   auto computation = module->AddEntryComputation(builder.Build());
@@ -637,6 +637,13 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
   Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
   Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
+  if (m == 1) {
+    dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
+    dot_shape = ShapeUtil::MakeShape(F32, {n});
+  } else if (n == 1) {
+    dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
+    dot_shape = ShapeUtil::MakeShape(F32, {m});
+  }
 
   auto* dot_lhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 7d6682c995d..8b92d67c0cb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -63,9 +63,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
 
 TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
   auto builder = HloComputation::Builder(TestName());
-  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
+  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
   Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
-  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
+  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
   auto dot_lhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
   auto dot_rhs = builder.AddInstruction(
@@ -83,12 +83,12 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
       ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
   AssignLayouts(module.get(), &computation_layout);
 
-  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
+  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
                                 dot_lhs->shape().layout()));
   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                 dot_rhs->shape().layout()));
-  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
-                                result->shape().layout()));
+  EXPECT_TRUE(
+      LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), result->shape().layout()));
   for (const auto& instruction : computation->instructions()) {
     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
   }
@@ -98,9 +98,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
   // Two dot products have the same constant as the RHS, and both those dot
   // products can be optimized if the constant has a column-major layout.
   auto builder = HloComputation::Builder(TestName());
-  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
+  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
   Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
-  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
+  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
   auto dot_a_lhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
   auto dot_b_lhs = builder.AddInstruction(
@@ -128,7 +128,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
                                 dot_rhs->shape().layout()));
   for (HloInstruction* instruction :
        {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
-    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
+    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
                                   instruction->shape().layout()));
   }
   for (const auto& instruction : computation->instructions()) {
@@ -270,6 +270,13 @@ static StatusOr RunDotOutputFusion(
   Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
   Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
+  if (m == 1) {
+    dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
+    dot_shape = ShapeUtil::MakeShape(F32, {n});
+  } else if (n == 1) {
+    dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
+    dot_shape = ShapeUtil::MakeShape(F32, {m});
+  }
 
   HloInstruction* dot_lhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
@@ -338,16 +345,21 @@ static void AssertCorrectLayoutForDotOutputFusion(
   Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
                                        ? LayoutUtil::MakeLayout({0, 1})
                                        : LayoutUtil::MakeLayout({1, 0});
+  if (layout_assignment_result.dot_rhs_fusion_param->shape().rank() == 1) {
+    expected_dot_rhs_layout = LayoutUtil::MakeLayout({0});
+  }
   EXPECT_TRUE(LayoutUtil::Equal(
       expected_dot_rhs_layout,
       layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
 
   EXPECT_TRUE(LayoutUtil::Equal(
-      LayoutUtil::MakeLayout({1, 0}),
+      LayoutUtil::MakeDescendingLayout(
+          layout_assignment_result.dot_lhs_fusion_param->shape().rank()),
       layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
 
   EXPECT_TRUE(LayoutUtil::Equal(
-      LayoutUtil::MakeLayout({1, 0}),
+      LayoutUtil::MakeDescendingLayout(
+          layout_assignment_result.addend_fusion_param->shape().rank()),
       layout_assignment_result.addend_fusion_param->shape().layout()));
 
   EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
 }
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index a421ea7d205..eb7d25609ff 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -702,19 +702,32 @@ Status DotOpEmitter::EmitCallToRuntime() {
 }
 
 DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
-  CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
+  CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
 
   const Shape& lhs_shape = lhs_array_.GetShape();
   const Shape& rhs_shape = rhs_array_.GetShape();
   const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
 
+  auto is_column_major = [](const Shape& shape) {
+    return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
+  };
+
+  // Non-contracting dots should never make it here.
+  CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
+  CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
+
   return {
-      /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
+      /*m=*/lhs_shape.rank() <= 1
+          ? 1LL
+          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
       /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
-      /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
-      /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
-      /*lhs_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 1,
-      /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
+      /*n=*/rhs_shape.rank() <= 1
+          ? 1LL
+          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
+      /*lhs_column_major=*/is_column_major(lhs_shape),
+      /*lhs_canonical=*/lhs_shape.rank() <= 1 ||
+          dim_nums.lhs_contracting_dimensions(0) == 1,
+      /*rhs_column_major=*/is_column_major(rhs_shape),
       /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
 }
 
@@ -722,8 +735,7 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
 // column major.
 absl::optional ProfitableToMakeDotOperandColumnMajor(
     const HloInstruction& hlo) {
-  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
-      hlo.shape().dimensions(0) == 1) {
+  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
     if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
       return 1;
     }
@@ -842,9 +854,10 @@ DotImplementationStrategy GetDotImplementationStrategy(
   // Any Matrix-Vector product of floating point or integral type, or
   // a transpose-dot fusion of the same can be lowered to a tiled LLVM
   // IR implementation.
-  if (dot_info.result_shape.dimensions_size() == 2 &&
-      (dot_info.result_shape.dimensions(0) == 1 ||
-       dot_info.result_shape.dimensions(1) == 1) &&
+  if ((dot_info.result_shape.dimensions_size() <= 1 ||
+       (dot_info.result_shape.dimensions_size() == 2 &&
+        (dot_info.result_shape.dimensions(0) == 1 ||
+         dot_info.result_shape.dimensions(1) == 1))) &&
       (primitive_util::IsFloatingPointType(element_type) ||
        primitive_util::IsIntegralType(element_type))) {
     return DotImplementationStrategy::kTiledLlvmIrGemv;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f11a4f66963..00292eb90cf 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -3019,7 +3019,8 @@ class HloInstruction::FusionReusesParamElements {
   }
 };
 
-HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
+HloInstruction::UseKind HloInstruction::OperandElementUse(
+    int64 operand_num) const {
   switch (opcode_) {
     case HloOpcode::kBitcast:
     case HloOpcode::kConcatenate:
@@ -3030,30 +3031,28 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
       return UseKind::kUsePermutingElements;
     case HloOpcode::kPad:
       // Pad reuses the padding value but not the padded array elements.
-      return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
+      return operand_num > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
    case HloOpcode::kReduce:
      // Reduce reuses the init values but not the operand array elements.
-      return i >= Cast(this)->input_count()
+      return operand_num >= Cast(this)->input_count()
                  ? UseKind::kReuse
                  : UseKind::kUsePermutingElements;
    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
-      return FusionReusesParamElements::Compute(i, *fused_expression_root());
+      return FusionReusesParamElements::Compute(operand_num,
+                                                *fused_expression_root());
    case HloOpcode::kDot:
-      // Dot operations with inputs [A,B] * [B,1] do not re-use
-      // elements on their left operand.
-      // Dot operations with inputs [1,A] * [A,B] do not re-use
-      // elements on their right operand.
-      if (shape().dimensions_size() == 2) {
-        if ((i == 0 && shape().dimensions(1) == 1) ||
-            (i == 1 && shape().dimensions(0) == 1)) {
+      // Matrix-vector dots do not reuse the matrix operand.
+      if (shape().dimensions_size() <= 1) {
+        if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
+            (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
           return UseKind::kUse;
         }
      }
      return UseKind::kReuse;
    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only start_indices.
-      if (i == 0 || i == 1) {
+      if (operand_num == 0 || operand_num == 1) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 450c48d988b..64e97ad5e46 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1768,8 +1768,8 @@ class HloInstruction {
   // Removes a user for this instruction.
   void RemoveUser(HloInstruction* user);
 
-  // Returns how this instruction uses elements of its `i`th operand.
-  UseKind OperandElementUse(int64 i) const;
+  // Returns how this instruction uses elements of its operand at operand_num.
+  UseKind OperandElementUse(int64 operand_num) const;
 
   // Helper for implementing backend_config(). Parses backend_config_ into the
   // given proto.
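
The kDot case above can be read as the following standalone predicate. This is
only an illustrative restatement of the logic in OperandElementUse (a sketch,
not code from the patch; the function name DotReusesOperandElements is made up):

  // Illustrative restatement: a dot whose result is rank <= 1 is a
  // vector-matrix, matrix-vector, or vector-vector product, and it reads each
  // element of its non-vector operand exactly once, so that operand is not
  // considered "reused".
  bool DotReusesOperandElements(const HloInstruction* dot, int64 operand_num) {
    if (dot->shape().dimensions_size() <= 1) {
      // operand_num is 0 or 1 for a dot; check whether the other operand is a
      // rank-1 vector.
      if (dot->operand(1 - operand_num)->shape().rank() <= 1) {
        return false;  // Corresponds to UseKind::kUse.
      }
    }
    return true;  // Corresponds to UseKind::kReuse.
  }

This matters for fusion: an operand that is not reused can be fused into the
dot without recomputing elements, which is exactly what the relaxed
ShouldFuse check in cpu_instruction_fusion.cc now allows for rank-1 results.
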
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 59c3d4f5c7e..25e82842b05 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -374,6 +374,11 @@ std::vector CreateDotTestParameters() {
     }
   };
 
+  add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
+  add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
+  add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
+  add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
+  add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
   add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
   add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
   add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 07dabc2cfaf..a71dd5abaec 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -585,13 +585,14 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
 
 std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs,
                                    HloInstruction* rhs) {
-  CHECK_EQ(lhs->shape().rank(), 2);
-  CHECK_EQ(rhs->shape().rank(), 2);
+  CHECK_LE(lhs->shape().rank(), 2);
+  CHECK_LE(rhs->shape().rank(), 2);
   PrecisionConfig precision_config;
   precision_config.mutable_operand_precision()->Resize(
       2, PrecisionConfig::DEFAULT);
   DotDimensionNumbers dot_dimension_numbers;
-  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+  dot_dimension_numbers.add_lhs_contracting_dimensions(
+      lhs->shape().rank() > 1 ? 1 : 0);
   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
   return absl::make_unique(
       shape, lhs, rhs, dot_dimension_numbers, precision_config);
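
As a usage sketch of the relaxed CreateCanonicalDot above (not part of the
patch; the builder name is invented and the enclosing test fixture, headers,
and xla namespace are assumed), a matrix-vector product can now be built with
a genuinely rank-1 operand:

  // Sketch: matrix-vector product f32[1024,256] dot f32[256] -> f32[1024].
  // With a rank-2 LHS, CreateCanonicalDot contracts LHS dimension 1 against
  // RHS dimension 0; the rank-1 RHS no longer needs a trailing one-dimension.
  HloComputation::Builder builder("canonical_dot_sketch");
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1024, 256}), "lhs"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256}), "rhs"));
  builder.AddInstruction(
      CreateCanonicalDot(ShapeUtil::MakeShape(F32, {1024}), lhs, rhs));

The new m=1, k=1, and n=1 cases added to CreateDotTestParameters exercise the
same rank-1 paths end to end through the CPU backend.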