Automated rollback of commit 76e256261ba4feee18aeb9a2e69d3d5c861ad976
PiperOrigin-RevId: 253207758
This commit is contained in:
parent
76e256261b
commit
f45c45d906
@ -1783,13 +1783,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
|
||||
// If the lhs or rhs have only batch and contracting dimensions, a dot can be
|
||||
// rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
|
||||
if (options_.enable_dot_strength_reduction() &&
|
||||
((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
|
||||
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
|
||||
lhs->shape().rank()) ||
|
||||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
|
||||
rhs->shape().rank()))) {
|
||||
if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
|
||||
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
|
||||
lhs->shape().rank()) ||
|
||||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
|
||||
rhs->shape().rank())) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * new_lhs,
|
||||
NormalizeDotOperandToBatchMajorAndContractingMinor(
|
||||
@ -1885,10 +1884,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
return ReplaceInstruction(dot, dot_of_gather_optimized);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
|
||||
RemoveDegenerateDimensionFromDot(dot));
|
||||
if (removed_degenerate_dimensions) {
|
||||
return Status::OK();
|
||||
if (options_.enable_dot_strength_reduction()) {
|
||||
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
|
||||
RemoveDegenerateDimensionFromDot(dot));
|
||||
if (removed_degenerate_dimensions) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
|
||||
|
@ -22,6 +22,11 @@ namespace cpu {
|
||||
|
||||
namespace {
|
||||
|
||||
int64 BytesInDimension(const Shape& shape, int64 dimension) {
|
||||
return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
|
||||
shape.dimensions(dimension);
|
||||
}
|
||||
|
||||
bool CanBeLoopFused(const HloInstruction& hlo) {
|
||||
// These are the only ones we fuse since we rely on effective elemental IR
|
||||
// generation.
|
||||
@ -42,7 +47,8 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
|
||||
bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
|
||||
const Shape& hlo_shape = hlo->shape();
|
||||
return !ShapeUtil::ElementIsComplex(hlo_shape) &&
|
||||
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1;
|
||||
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 &&
|
||||
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
|
||||
}
|
||||
|
||||
bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
|
||||
@ -130,21 +136,18 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
// fusion can easily be overshadowed by the overhead of a naive GEMM
|
||||
// algorithm in the IR.
|
||||
const Shape& output_shape = consumer->shape();
|
||||
if (output_shape.dimensions_size() <= 1) {
|
||||
// We fuse in cases where we have a matrix*vector or vector*matrix dot and
|
||||
if (output_shape.dimensions_size() == 2) {
|
||||
// We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B]) and
|
||||
// fusion can get rid of the larger tensor. We assume that a naive
|
||||
// traversal of a small enough (to fit in L1) column or row tensor is
|
||||
// "good enough" from the perspective of cache management; and calling out
|
||||
// to an optimized GEMM kernel is not a huge win.
|
||||
if (consumer->operand(0)->shape().rank() == 1 && operand_index == 1 &&
|
||||
ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) <
|
||||
kFusionThresholdBytes) {
|
||||
if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
|
||||
BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
|
||||
VLOG(2) << "Fusing small matrix-vector product.";
|
||||
return true;
|
||||
} else if (consumer->operand(1)->shape().rank() == 1 &&
|
||||
operand_index == 0 &&
|
||||
ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) <
|
||||
kFusionThresholdBytes) {
|
||||
} else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
|
||||
BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
|
||||
VLOG(2) << "Fusing small matrix-vector product.";
|
||||
return true;
|
||||
}
|
||||
|
@ -51,12 +51,12 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
|
||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
|
||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
||||
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||
|
||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1));
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
@ -68,14 +68,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
|
||||
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {256}), "arg0"));
|
||||
0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
|
||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
|
||||
|
||||
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1));
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
@ -89,14 +89,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
|
||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
||||
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||
|
||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
||||
HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1));
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
@ -110,7 +110,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
||||
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||
|
||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
||||
@ -118,7 +118,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
||||
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1));
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
@ -130,14 +130,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
||||
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0"));
|
||||
0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
|
||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
|
||||
|
||||
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
|
||||
HloInstruction* dot = builder.AddInstruction(
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {32 * 1024}), arg0, exp1));
|
||||
MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
|
||||
|
||||
auto module = CreateNewUnverifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
@ -637,13 +637,6 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
|
||||
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
||||
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
||||
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
||||
if (m == 1) {
|
||||
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
|
||||
dot_shape = ShapeUtil::MakeShape(F32, {n});
|
||||
} else if (n == 1) {
|
||||
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
|
||||
dot_shape = ShapeUtil::MakeShape(F32, {m});
|
||||
}
|
||||
|
||||
auto* dot_lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
||||
|
@ -63,9 +63,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
|
||||
|
||||
TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
|
||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
|
||||
auto dot_lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
||||
auto dot_rhs = builder.AddInstruction(
|
||||
@ -83,12 +83,12 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
|
||||
ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
|
||||
AssignLayouts(module.get(), &computation_layout);
|
||||
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||
dot_lhs->shape().layout()));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
|
||||
dot_rhs->shape().layout()));
|
||||
EXPECT_TRUE(
|
||||
LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), result->shape().layout()));
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||
result->shape().layout()));
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
|
||||
}
|
||||
@ -98,9 +98,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
|
||||
// Two dot products have the same constant as the RHS, and both those dot
|
||||
// products can be optimized if the constant has a column-major layout.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
|
||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
|
||||
auto dot_a_lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
||||
auto dot_b_lhs = builder.AddInstruction(
|
||||
@ -128,7 +128,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
|
||||
dot_rhs->shape().layout()));
|
||||
for (HloInstruction* instruction :
|
||||
{dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
|
||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||
instruction->shape().layout()));
|
||||
}
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
@ -270,13 +270,6 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
|
||||
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
||||
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
||||
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
||||
if (m == 1) {
|
||||
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
|
||||
dot_shape = ShapeUtil::MakeShape(F32, {n});
|
||||
} else if (n == 1) {
|
||||
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
|
||||
dot_shape = ShapeUtil::MakeShape(F32, {m});
|
||||
}
|
||||
|
||||
HloInstruction* dot_lhs = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
||||
@ -345,21 +338,16 @@ static void AssertCorrectLayoutForDotOutputFusion(
|
||||
Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
|
||||
? LayoutUtil::MakeLayout({0, 1})
|
||||
: LayoutUtil::MakeLayout({1, 0});
|
||||
if (layout_assignment_result.dot_rhs_fusion_param->shape().rank() == 1) {
|
||||
expected_dot_rhs_layout = LayoutUtil::MakeLayout({0});
|
||||
}
|
||||
EXPECT_TRUE(LayoutUtil::Equal(
|
||||
expected_dot_rhs_layout,
|
||||
layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
|
||||
|
||||
EXPECT_TRUE(LayoutUtil::Equal(
|
||||
LayoutUtil::MakeDescendingLayout(
|
||||
layout_assignment_result.dot_lhs_fusion_param->shape().rank()),
|
||||
LayoutUtil::MakeLayout({1, 0}),
|
||||
layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
|
||||
|
||||
EXPECT_TRUE(LayoutUtil::Equal(
|
||||
LayoutUtil::MakeDescendingLayout(
|
||||
layout_assignment_result.addend_fusion_param->shape().rank()),
|
||||
LayoutUtil::MakeLayout({1, 0}),
|
||||
layout_assignment_result.addend_fusion_param->shape().layout()));
|
||||
EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
|
||||
}
|
||||
|
@ -702,32 +702,19 @@ Status DotOpEmitter::EmitCallToRuntime() {
|
||||
}
|
||||
|
||||
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
||||
CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
|
||||
CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
|
||||
|
||||
const Shape& lhs_shape = lhs_array_.GetShape();
|
||||
const Shape& rhs_shape = rhs_array_.GetShape();
|
||||
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
|
||||
|
||||
auto is_column_major = [](const Shape& shape) {
|
||||
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
|
||||
};
|
||||
|
||||
// Non-contracting dots should never make it here.
|
||||
CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
|
||||
CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
|
||||
|
||||
return {
|
||||
/*m=*/lhs_shape.rank() <= 1
|
||||
? 1LL
|
||||
: lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
|
||||
/*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
|
||||
/*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
|
||||
/*n=*/rhs_shape.rank() <= 1
|
||||
? 1LL
|
||||
: rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
|
||||
/*lhs_column_major=*/is_column_major(lhs_shape),
|
||||
/*lhs_canonical=*/lhs_shape.rank() <= 1 ||
|
||||
dim_nums.lhs_contracting_dimensions(0) == 1,
|
||||
/*rhs_column_major=*/is_column_major(rhs_shape),
|
||||
/*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
|
||||
/*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
|
||||
/*lhs_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 1,
|
||||
/*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
|
||||
/*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
|
||||
}
|
||||
|
||||
@ -735,7 +722,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
||||
// column major.
|
||||
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
|
||||
const HloInstruction& hlo) {
|
||||
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
|
||||
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
|
||||
hlo.shape().dimensions(0) == 1) {
|
||||
if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
|
||||
return 1;
|
||||
}
|
||||
@ -854,10 +842,9 @@ DotImplementationStrategy GetDotImplementationStrategy(
|
||||
// Any Matrix-Vector product of floating point or integral type, or
|
||||
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
|
||||
// IR implementation.
|
||||
if ((dot_info.result_shape.dimensions_size() <= 1 ||
|
||||
(dot_info.result_shape.dimensions_size() == 2 &&
|
||||
(dot_info.result_shape.dimensions(0) == 1 ||
|
||||
dot_info.result_shape.dimensions(1) == 1))) &&
|
||||
if (dot_info.result_shape.dimensions_size() == 2 &&
|
||||
(dot_info.result_shape.dimensions(0) == 1 ||
|
||||
dot_info.result_shape.dimensions(1) == 1) &&
|
||||
(primitive_util::IsFloatingPointType(element_type) ||
|
||||
primitive_util::IsIntegralType(element_type))) {
|
||||
return DotImplementationStrategy::kTiledLlvmIrGemv;
|
||||
|
@ -3019,8 +3019,7 @@ class HloInstruction::FusionReusesParamElements {
|
||||
}
|
||||
};
|
||||
|
||||
HloInstruction::UseKind HloInstruction::OperandElementUse(
|
||||
int64 operand_num) const {
|
||||
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
|
||||
switch (opcode_) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kConcatenate:
|
||||
@ -3031,28 +3030,30 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(
|
||||
return UseKind::kUsePermutingElements;
|
||||
case HloOpcode::kPad:
|
||||
// Pad reuses the padding value but not the padded array elements.
|
||||
return operand_num > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
|
||||
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
|
||||
case HloOpcode::kReduce:
|
||||
// Reduce reuses the init values but not the operand array elements.
|
||||
return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
|
||||
return i >= Cast<HloReduceInstruction>(this)->input_count()
|
||||
? UseKind::kReuse
|
||||
: UseKind::kUsePermutingElements;
|
||||
case HloOpcode::kFusion:
|
||||
// Uses the memoizing, recursive computation defined above.
|
||||
return FusionReusesParamElements::Compute(operand_num,
|
||||
*fused_expression_root());
|
||||
return FusionReusesParamElements::Compute(i, *fused_expression_root());
|
||||
case HloOpcode::kDot:
|
||||
// Matrix-vector dots do not reuse the matrix operand.
|
||||
if (shape().dimensions_size() <= 1) {
|
||||
if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
|
||||
(operand_num == 1 && operand(0)->shape().rank() <= 1)) {
|
||||
// Dot operations with inputs [A,B] * [B,1] do not re-use
|
||||
// elements on their left operand.
|
||||
// Dot operations with inputs [1,A] * [A,B] do not re-use
|
||||
// elements on their right operand.
|
||||
if (shape().dimensions_size() == 2) {
|
||||
if ((i == 0 && shape().dimensions(1) == 1) ||
|
||||
(i == 1 && shape().dimensions(0) == 1)) {
|
||||
return UseKind::kUse;
|
||||
}
|
||||
}
|
||||
return UseKind::kReuse;
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
// Dynamic-update-slice reuses only start_indices.
|
||||
if (operand_num == 0 || operand_num == 1) {
|
||||
if (i == 0 || i == 1) {
|
||||
return UseKind::kUse;
|
||||
}
|
||||
return UseKind::kReuse;
|
||||
|
@ -1768,8 +1768,8 @@ class HloInstruction {
|
||||
// Removes a user for this instruction.
|
||||
void RemoveUser(HloInstruction* user);
|
||||
|
||||
// Returns how this instruction uses elements of its operand at operand_num.
|
||||
UseKind OperandElementUse(int64 operand_num) const;
|
||||
// Returns how this instruction uses elements of its `i`th operand.
|
||||
UseKind OperandElementUse(int64 i) const;
|
||||
|
||||
// Helper for implementing backend_config(). Parses backend_config_ into the
|
||||
// given proto.
|
||||
|
@ -374,11 +374,6 @@ std::vector<DotTestParam> CreateDotTestParameters() {
|
||||
}
|
||||
};
|
||||
|
||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
|
||||
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
|
||||
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
|
||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
|
||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
|
||||
add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
|
||||
add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
|
||||
add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
|
||||
|
@ -585,14 +585,13 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
|
||||
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
|
||||
HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
CHECK_LE(lhs->shape().rank(), 2);
|
||||
CHECK_LE(rhs->shape().rank(), 2);
|
||||
CHECK_EQ(lhs->shape().rank(), 2);
|
||||
CHECK_EQ(rhs->shape().rank(), 2);
|
||||
PrecisionConfig precision_config;
|
||||
precision_config.mutable_operand_precision()->Resize(
|
||||
2, PrecisionConfig::DEFAULT);
|
||||
DotDimensionNumbers dot_dimension_numbers;
|
||||
dot_dimension_numbers.add_lhs_contracting_dimensions(
|
||||
lhs->shape().rank() > 1 ? 1 : 0);
|
||||
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
|
||||
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
|
||||
return absl::make_unique<HloDotInstruction>(
|
||||
shape, lhs, rhs, dot_dimension_numbers, precision_config);
|
||||
|
Loading…
x
Reference in New Issue
Block a user