From f45c45d90655e34af1b4a5a52dc5c1a0bbb2ac87 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Jun 2019 04:51:35 -0700 Subject: [PATCH] Automated rollback of commit 76e256261ba4feee18aeb9a2e69d3d5c861ad976 PiperOrigin-RevId: 253207758 --- .../xla/service/algebraic_simplifier.cc | 23 ++++++------ .../xla/service/cpu/cpu_instruction_fusion.cc | 23 ++++++------ .../cpu/cpu_instruction_fusion_test.cc | 27 ++++++-------- .../service/cpu/cpu_layout_assignment_test.cc | 32 ++++++----------- .../xla/service/cpu/dot_op_emitter.cc | 35 ++++++------------- .../compiler/xla/service/hlo_instruction.cc | 23 ++++++------ .../compiler/xla/service/hlo_instruction.h | 4 +-- .../compiler/xla/tests/dot_operation_test.cc | 5 --- tensorflow/compiler/xla/tests/test_utils.cc | 7 ++-- 9 files changed, 73 insertions(+), 106 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 697c9313073..9d3707656c9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1783,13 +1783,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) - if (options_.enable_dot_strength_reduction() && - ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + - dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == - lhs->shape().rank()) || - (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + - dot->dot_dimension_numbers().rhs_batch_dimensions_size() == - rhs->shape().rank()))) { + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { TF_ASSIGN_OR_RETURN( HloInstruction * new_lhs, NormalizeDotOperandToBatchMajorAndContractingMinor( @@ -1885,10 +1884,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, - RemoveDegenerateDimensionFromDot(dot)); - if (removed_degenerate_dimensions) { - return Status::OK(); + if (options_.enable_dot_strength_reduction()) { + TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, + RemoveDegenerateDimensionFromDot(dot)); + if (removed_degenerate_dimensions) { + return Status::OK(); + } } // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index a37b5e9e51b..3d7c06cfa27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -22,6 +22,11 @@ namespace cpu { namespace { +int64 BytesInDimension(const Shape& shape, int64 dimension) { + return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) * + shape.dimensions(dimension); +} + bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. @@ -42,7 +47,8 @@ bool CanBeLoopFused(const HloInstruction& hlo) { bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) { const Shape& hlo_shape = hlo->shape(); return !ShapeUtil::ElementIsComplex(hlo_shape) && - hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1; + hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); } bool HasExactlyOneUse(const HloInstruction& hlo_instr) { @@ -130,21 +136,18 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // fusion can easily be overshadowed by the overhead of a naive GEMM // algorithm in the IR. const Shape& output_shape = consumer->shape(); - if (output_shape.dimensions_size() <= 1) { - // We fuse in cases where we have a matrix*vector or vector*matrix dot and + if (output_shape.dimensions_size() == 2) { + // We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B]) and // fusion can get rid of the larger tensor. We assume that a naive // traversal of a small enough (to fit in L1) column or row tensor is // "good enough" from the perspective of cache management; and calling out // to an optimized GEMM kernel is not a huge win. - if (consumer->operand(0)->shape().rank() == 1 && operand_index == 1 && - ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) < - kFusionThresholdBytes) { + if (output_shape.dimensions(0) == 1 && operand_index == 1 && + BytesInDimension(output_shape, 1) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; return true; - } else if (consumer->operand(1)->shape().rank() == 1 && - operand_index == 0 && - ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) < - kFusionThresholdBytes) { + } else if (output_shape.dimensions(1) == 1 && operand_index == 0 && + BytesInDimension(output_shape, 0) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index e1f5ab11456..b35026a41cd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -51,12 +51,12 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0")); HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {256}), "arg1")); + 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); HloInstruction* dot = builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1)); + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -68,14 +68,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {256}), "arg0")); + 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1")); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); HloInstruction* dot = builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1)); + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -89,14 +89,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {256}), "arg1")); + 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); HloInstruction* dot = builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1)); + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -110,7 +110,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {256}), "arg1")); + 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); @@ -118,7 +118,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); HloInstruction* dot = builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1)); + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -130,14 +130,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0")); + 0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0")); HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1")); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); HloInstruction* dot = builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {32 * 1024}), arg0, exp1)); + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -637,13 +637,6 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - if (m == 1) { - dot_lhs_shape = ShapeUtil::MakeShape(F32, {k}); - dot_shape = ShapeUtil::MakeShape(F32, {n}); - } else if (n == 1) { - dot_rhs_shape = ShapeUtil::MakeShape(F32, {k}); - dot_shape = ShapeUtil::MakeShape(F32, {m}); - } auto* dot_lhs = builder.AddInstruction( HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 8b92d67c0cb..7d6682c995d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -63,9 +63,9 @@ class CpuLayoutAssignmentTest : public HloTestBase { TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto builder = HloComputation::Builder(TestName()); - Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0}); + Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); - Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0}); + Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1}); auto dot_lhs = builder.AddInstruction( HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( @@ -83,12 +83,12 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape)); AssignLayouts(module.get(), &computation_layout); - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), dot_lhs->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), dot_rhs->shape().layout())); - EXPECT_TRUE( - LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), result->shape().layout())); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), + result->shape().layout())); for (const auto& instruction : computation->instructions()) { EXPECT_NE(instruction->opcode(), HloOpcode::kCopy); } @@ -98,9 +98,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { // Two dot products have the same constant as the RHS, and both those dot // products can be optimized if the constant has a column-major layout. auto builder = HloComputation::Builder(TestName()); - Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0}); + Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); - Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0}); + Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1}); auto dot_a_lhs = builder.AddInstruction( HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_b_lhs = builder.AddInstruction( @@ -128,7 +128,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { dot_rhs->shape().layout())); for (HloInstruction* instruction : {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) { - EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), instruction->shape().layout())); } for (const auto& instruction : computation->instructions()) { @@ -270,13 +270,6 @@ static StatusOr RunDotOutputFusion( Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - if (m == 1) { - dot_lhs_shape = ShapeUtil::MakeShape(F32, {k}); - dot_shape = ShapeUtil::MakeShape(F32, {n}); - } else if (n == 1) { - dot_rhs_shape = ShapeUtil::MakeShape(F32, {k}); - dot_shape = ShapeUtil::MakeShape(F32, {m}); - } HloInstruction* dot_lhs = builder.AddInstruction( HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); @@ -345,21 +338,16 @@ static void AssertCorrectLayoutForDotOutputFusion( Layout expected_dot_rhs_layout = expect_col_major_dot_rhs ? LayoutUtil::MakeLayout({0, 1}) : LayoutUtil::MakeLayout({1, 0}); - if (layout_assignment_result.dot_rhs_fusion_param->shape().rank() == 1) { - expected_dot_rhs_layout = LayoutUtil::MakeLayout({0}); - } EXPECT_TRUE(LayoutUtil::Equal( expected_dot_rhs_layout, layout_assignment_result.dot_rhs_fusion_param->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( - LayoutUtil::MakeDescendingLayout( - layout_assignment_result.dot_lhs_fusion_param->shape().rank()), + LayoutUtil::MakeLayout({1, 0}), layout_assignment_result.dot_lhs_fusion_param->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( - LayoutUtil::MakeDescendingLayout( - layout_assignment_result.addend_fusion_param->shape().rank()), + LayoutUtil::MakeLayout({1, 0}), layout_assignment_result.addend_fusion_param->shape().layout())); EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index eb7d25609ff..a421ea7d205 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -702,32 +702,19 @@ Status DotOpEmitter::EmitCallToRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_LE(dot_info_.result_shape.dimensions_size(), 2); + CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; - auto is_column_major = [](const Shape& shape) { - return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0; - }; - - // Non-contracting dots should never make it here. - CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0); - CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0); - return { - /*m=*/lhs_shape.rank() <= 1 - ? 1LL - : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)), + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), - /*n=*/rhs_shape.rank() <= 1 - ? 1LL - : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)), - /*lhs_column_major=*/is_column_major(lhs_shape), - /*lhs_canonical=*/lhs_shape.rank() <= 1 || - dim_nums.lhs_contracting_dimensions(0) == 1, - /*rhs_column_major=*/is_column_major(rhs_shape), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 1, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0}; } @@ -735,7 +722,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { // column major. absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo) { - if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) { + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && + hlo.shape().dimensions(0) == 1) { if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) { return 1; } @@ -854,10 +842,9 @@ DotImplementationStrategy GetDotImplementationStrategy( // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. - if ((dot_info.result_shape.dimensions_size() <= 1 || - (dot_info.result_shape.dimensions_size() == 2 && - (dot_info.result_shape.dimensions(0) == 1 || - dot_info.result_shape.dimensions(1) == 1))) && + if (dot_info.result_shape.dimensions_size() == 2 && + (dot_info.result_shape.dimensions(0) == 1 || + dot_info.result_shape.dimensions(1) == 1) && (primitive_util::IsFloatingPointType(element_type) || primitive_util::IsIntegralType(element_type))) { return DotImplementationStrategy::kTiledLlvmIrGemv; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 00292eb90cf..f11a4f66963 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3019,8 +3019,7 @@ class HloInstruction::FusionReusesParamElements { } }; -HloInstruction::UseKind HloInstruction::OperandElementUse( - int64 operand_num) const { +HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { switch (opcode_) { case HloOpcode::kBitcast: case HloOpcode::kConcatenate: @@ -3031,28 +3030,30 @@ HloInstruction::UseKind HloInstruction::OperandElementUse( return UseKind::kUsePermutingElements; case HloOpcode::kPad: // Pad reuses the padding value but not the padded array elements. - return operand_num > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; + return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; case HloOpcode::kReduce: // Reduce reuses the init values but not the operand array elements. - return operand_num >= Cast(this)->input_count() + return i >= Cast(this)->input_count() ? UseKind::kReuse : UseKind::kUsePermutingElements; case HloOpcode::kFusion: // Uses the memoizing, recursive computation defined above. - return FusionReusesParamElements::Compute(operand_num, - *fused_expression_root()); + return FusionReusesParamElements::Compute(i, *fused_expression_root()); case HloOpcode::kDot: - // Matrix-vector dots do not reuse the matrix operand. - if (shape().dimensions_size() <= 1) { - if ((operand_num == 0 && operand(1)->shape().rank() <= 1) || - (operand_num == 1 && operand(0)->shape().rank() <= 1)) { + // Dot operations with inputs [A,B] * [B,1] do not re-use + // elements on their left operand. + // Dot operations with inputs [1,A] * [A,B] do not re-use + // elements on their right operand. + if (shape().dimensions_size() == 2) { + if ((i == 0 && shape().dimensions(1) == 1) || + (i == 1 && shape().dimensions(0) == 1)) { return UseKind::kUse; } } return UseKind::kReuse; case HloOpcode::kDynamicUpdateSlice: // Dynamic-update-slice reuses only start_indices. - if (operand_num == 0 || operand_num == 1) { + if (i == 0 || i == 1) { return UseKind::kUse; } return UseKind::kReuse; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 64e97ad5e46..450c48d988b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1768,8 +1768,8 @@ class HloInstruction { // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Returns how this instruction uses elements of its operand at operand_num. - UseKind OperandElementUse(int64 operand_num) const; + // Returns how this instruction uses elements of its `i`th operand. + UseKind OperandElementUse(int64 i) const; // Helper for implementing backend_config(). Parses backend_config_ into the // given proto. diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 25e82842b05..59c3d4f5c7e 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -374,11 +374,6 @@ std::vector CreateDotTestParameters() { } }; - add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42); - add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42); - add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1); - add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1); - add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1); add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index a71dd5abaec..07dabc2cfaf 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -585,14 +585,13 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_LE(lhs->shape().rank(), 2); - CHECK_LE(rhs->shape().rank(), 2); + CHECK_EQ(lhs->shape().rank(), 2); + CHECK_EQ(rhs->shape().rank(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT); DotDimensionNumbers dot_dimension_numbers; - dot_dimension_numbers.add_lhs_contracting_dimensions( - lhs->shape().rank() > 1 ? 1 : 0); + dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); return absl::make_unique( shape, lhs, rhs, dot_dimension_numbers, precision_config);