Automated rollback of commit 76e256261ba4feee18aeb9a2e69d3d5c861ad976
PiperOrigin-RevId: 253207758
This commit is contained in:
parent
76e256261b
commit
f45c45d906
@ -1783,13 +1783,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
|||||||
|
|
||||||
// If the lhs or rhs have only batch and contracting dimensions, a dot can be
|
// If the lhs or rhs have only batch and contracting dimensions, a dot can be
|
||||||
// rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
|
// rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
|
||||||
if (options_.enable_dot_strength_reduction() &&
|
if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
|
||||||
((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
|
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
|
||||||
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
|
lhs->shape().rank()) ||
|
||||||
lhs->shape().rank()) ||
|
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
|
||||||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
|
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
|
||||||
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
|
rhs->shape().rank())) {
|
||||||
rhs->shape().rank()))) {
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
HloInstruction * new_lhs,
|
HloInstruction * new_lhs,
|
||||||
NormalizeDotOperandToBatchMajorAndContractingMinor(
|
NormalizeDotOperandToBatchMajorAndContractingMinor(
|
||||||
@ -1885,10 +1884,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
|||||||
return ReplaceInstruction(dot, dot_of_gather_optimized);
|
return ReplaceInstruction(dot, dot_of_gather_optimized);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
|
if (options_.enable_dot_strength_reduction()) {
|
||||||
RemoveDegenerateDimensionFromDot(dot));
|
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
|
||||||
if (removed_degenerate_dimensions) {
|
RemoveDegenerateDimensionFromDot(dot));
|
||||||
return Status::OK();
|
if (removed_degenerate_dimensions) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
|
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
|
||||||
|
@ -22,6 +22,11 @@ namespace cpu {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
int64 BytesInDimension(const Shape& shape, int64 dimension) {
|
||||||
|
return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
|
||||||
|
shape.dimensions(dimension);
|
||||||
|
}
|
||||||
|
|
||||||
bool CanBeLoopFused(const HloInstruction& hlo) {
|
bool CanBeLoopFused(const HloInstruction& hlo) {
|
||||||
// These are the only ones we fuse since we rely on effective elemental IR
|
// These are the only ones we fuse since we rely on effective elemental IR
|
||||||
// generation.
|
// generation.
|
||||||
@ -42,7 +47,8 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
|
|||||||
bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
|
bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
|
||||||
const Shape& hlo_shape = hlo->shape();
|
const Shape& hlo_shape = hlo->shape();
|
||||||
return !ShapeUtil::ElementIsComplex(hlo_shape) &&
|
return !ShapeUtil::ElementIsComplex(hlo_shape) &&
|
||||||
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1;
|
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 &&
|
||||||
|
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
|
bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
|
||||||
@ -130,21 +136,18 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
|||||||
// fusion can easily be overshadowed by the overhead of a naive GEMM
|
// fusion can easily be overshadowed by the overhead of a naive GEMM
|
||||||
// algorithm in the IR.
|
// algorithm in the IR.
|
||||||
const Shape& output_shape = consumer->shape();
|
const Shape& output_shape = consumer->shape();
|
||||||
if (output_shape.dimensions_size() <= 1) {
|
if (output_shape.dimensions_size() == 2) {
|
||||||
// We fuse in cases where we have a matrix*vector or vector*matrix dot and
|
// We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B]) and
|
||||||
// fusion can get rid of the larger tensor. We assume that a naive
|
// fusion can get rid of the larger tensor. We assume that a naive
|
||||||
// traversal of a small enough (to fit in L1) column or row tensor is
|
// traversal of a small enough (to fit in L1) column or row tensor is
|
||||||
// "good enough" from the perspective of cache management; and calling out
|
// "good enough" from the perspective of cache management; and calling out
|
||||||
// to an optimized GEMM kernel is not a huge win.
|
// to an optimized GEMM kernel is not a huge win.
|
||||||
if (consumer->operand(0)->shape().rank() == 1 && operand_index == 1 &&
|
if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
|
||||||
ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) <
|
BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
|
||||||
kFusionThresholdBytes) {
|
|
||||||
VLOG(2) << "Fusing small matrix-vector product.";
|
VLOG(2) << "Fusing small matrix-vector product.";
|
||||||
return true;
|
return true;
|
||||||
} else if (consumer->operand(1)->shape().rank() == 1 &&
|
} else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
|
||||||
operand_index == 0 &&
|
BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
|
||||||
ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) <
|
|
||||||
kFusionThresholdBytes) {
|
|
||||||
VLOG(2) << "Fusing small matrix-vector product.";
|
VLOG(2) << "Fusing small matrix-vector product.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -51,12 +51,12 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
|
|||||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
|
0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
|
||||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||||
|
|
||||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
|
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
|
||||||
HloInstruction* dot = builder.AddInstruction(
|
HloInstruction* dot = builder.AddInstruction(
|
||||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1));
|
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
|
||||||
|
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
@ -68,14 +68,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
|
|||||||
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
|
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(F32, {256}), "arg0"));
|
0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
|
||||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
|
1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
|
||||||
|
|
||||||
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
|
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
|
||||||
HloInstruction* dot = builder.AddInstruction(
|
HloInstruction* dot = builder.AddInstruction(
|
||||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1));
|
MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
|
||||||
|
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
@ -89,14 +89,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
|
|||||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
||||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||||
|
|
||||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
||||||
HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
|
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
|
||||||
HloInstruction* dot = builder.AddInstruction(
|
HloInstruction* dot = builder.AddInstruction(
|
||||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1));
|
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
|
||||||
|
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
@ -110,7 +110,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
|||||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
|
||||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
|
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
|
||||||
|
|
||||||
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
|
||||||
@ -118,7 +118,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
|||||||
builder.AddInstruction(HloInstruction::CreateReshape(
|
builder.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
|
ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
|
||||||
HloInstruction* dot = builder.AddInstruction(
|
HloInstruction* dot = builder.AddInstruction(
|
||||||
MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1));
|
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
|
||||||
|
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
@ -130,14 +130,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
|
|||||||
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
|
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0"));
|
0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
|
||||||
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
|
1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
|
||||||
|
|
||||||
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||||
ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
|
ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
|
||||||
HloInstruction* dot = builder.AddInstruction(
|
HloInstruction* dot = builder.AddInstruction(
|
||||||
MakeDot(ShapeUtil::MakeShape(F32, {32 * 1024}), arg0, exp1));
|
MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
|
||||||
|
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
@ -637,13 +637,6 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
|
|||||||
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
||||||
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
||||||
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
||||||
if (m == 1) {
|
|
||||||
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
|
|
||||||
dot_shape = ShapeUtil::MakeShape(F32, {n});
|
|
||||||
} else if (n == 1) {
|
|
||||||
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
|
|
||||||
dot_shape = ShapeUtil::MakeShape(F32, {m});
|
|
||||||
}
|
|
||||||
|
|
||||||
auto* dot_lhs = builder.AddInstruction(
|
auto* dot_lhs = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
||||||
|
@ -63,9 +63,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
|
|||||||
|
|
||||||
TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
|
TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
|
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
|
||||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
||||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
|
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
|
||||||
auto dot_lhs = builder.AddInstruction(
|
auto dot_lhs = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
||||||
auto dot_rhs = builder.AddInstruction(
|
auto dot_rhs = builder.AddInstruction(
|
||||||
@ -83,12 +83,12 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
|
|||||||
ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
|
ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
|
||||||
AssignLayouts(module.get(), &computation_layout);
|
AssignLayouts(module.get(), &computation_layout);
|
||||||
|
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
|
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||||
dot_lhs->shape().layout()));
|
dot_lhs->shape().layout()));
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
|
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
|
||||||
dot_rhs->shape().layout()));
|
dot_rhs->shape().layout()));
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||||
LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), result->shape().layout()));
|
result->shape().layout()));
|
||||||
for (const auto& instruction : computation->instructions()) {
|
for (const auto& instruction : computation->instructions()) {
|
||||||
EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
|
EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
|
||||||
}
|
}
|
||||||
@ -98,9 +98,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
|
|||||||
// Two dot products have the same constant as the RHS, and both those dot
|
// Two dot products have the same constant as the RHS, and both those dot
|
||||||
// products can be optimized if the constant has a column-major layout.
|
// products can be optimized if the constant has a column-major layout.
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
|
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
|
||||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
|
||||||
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
|
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
|
||||||
auto dot_a_lhs = builder.AddInstruction(
|
auto dot_a_lhs = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
|
||||||
auto dot_b_lhs = builder.AddInstruction(
|
auto dot_b_lhs = builder.AddInstruction(
|
||||||
@ -128,7 +128,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
|
|||||||
dot_rhs->shape().layout()));
|
dot_rhs->shape().layout()));
|
||||||
for (HloInstruction* instruction :
|
for (HloInstruction* instruction :
|
||||||
{dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
|
{dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
|
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
|
||||||
instruction->shape().layout()));
|
instruction->shape().layout()));
|
||||||
}
|
}
|
||||||
for (const auto& instruction : computation->instructions()) {
|
for (const auto& instruction : computation->instructions()) {
|
||||||
@ -270,13 +270,6 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
|
|||||||
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
|
||||||
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
|
||||||
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
|
||||||
if (m == 1) {
|
|
||||||
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
|
|
||||||
dot_shape = ShapeUtil::MakeShape(F32, {n});
|
|
||||||
} else if (n == 1) {
|
|
||||||
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
|
|
||||||
dot_shape = ShapeUtil::MakeShape(F32, {m});
|
|
||||||
}
|
|
||||||
|
|
||||||
HloInstruction* dot_lhs = builder.AddInstruction(
|
HloInstruction* dot_lhs = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
|
||||||
@ -345,21 +338,16 @@ static void AssertCorrectLayoutForDotOutputFusion(
|
|||||||
Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
|
Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
|
||||||
? LayoutUtil::MakeLayout({0, 1})
|
? LayoutUtil::MakeLayout({0, 1})
|
||||||
: LayoutUtil::MakeLayout({1, 0});
|
: LayoutUtil::MakeLayout({1, 0});
|
||||||
if (layout_assignment_result.dot_rhs_fusion_param->shape().rank() == 1) {
|
|
||||||
expected_dot_rhs_layout = LayoutUtil::MakeLayout({0});
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
expected_dot_rhs_layout,
|
expected_dot_rhs_layout,
|
||||||
layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
|
layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
|
||||||
|
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
LayoutUtil::MakeDescendingLayout(
|
LayoutUtil::MakeLayout({1, 0}),
|
||||||
layout_assignment_result.dot_lhs_fusion_param->shape().rank()),
|
|
||||||
layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
|
layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
|
||||||
|
|
||||||
EXPECT_TRUE(LayoutUtil::Equal(
|
EXPECT_TRUE(LayoutUtil::Equal(
|
||||||
LayoutUtil::MakeDescendingLayout(
|
LayoutUtil::MakeLayout({1, 0}),
|
||||||
layout_assignment_result.addend_fusion_param->shape().rank()),
|
|
||||||
layout_assignment_result.addend_fusion_param->shape().layout()));
|
layout_assignment_result.addend_fusion_param->shape().layout()));
|
||||||
EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
|
EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
|
||||||
}
|
}
|
||||||
|
@ -702,32 +702,19 @@ Status DotOpEmitter::EmitCallToRuntime() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
||||||
CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
|
CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
|
||||||
|
|
||||||
const Shape& lhs_shape = lhs_array_.GetShape();
|
const Shape& lhs_shape = lhs_array_.GetShape();
|
||||||
const Shape& rhs_shape = rhs_array_.GetShape();
|
const Shape& rhs_shape = rhs_array_.GetShape();
|
||||||
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
|
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
|
||||||
|
|
||||||
auto is_column_major = [](const Shape& shape) {
|
|
||||||
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Non-contracting dots should never make it here.
|
|
||||||
CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
|
|
||||||
CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
/*m=*/lhs_shape.rank() <= 1
|
/*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
|
||||||
? 1LL
|
|
||||||
: lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
|
|
||||||
/*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
|
/*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
|
||||||
/*n=*/rhs_shape.rank() <= 1
|
/*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
|
||||||
? 1LL
|
/*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
|
||||||
: rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
|
/*lhs_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 1,
|
||||||
/*lhs_column_major=*/is_column_major(lhs_shape),
|
/*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
|
||||||
/*lhs_canonical=*/lhs_shape.rank() <= 1 ||
|
|
||||||
dim_nums.lhs_contracting_dimensions(0) == 1,
|
|
||||||
/*rhs_column_major=*/is_column_major(rhs_shape),
|
|
||||||
/*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
|
/*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -735,7 +722,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
|
|||||||
// column major.
|
// column major.
|
||||||
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
|
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
|
||||||
const HloInstruction& hlo) {
|
const HloInstruction& hlo) {
|
||||||
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
|
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
|
||||||
|
hlo.shape().dimensions(0) == 1) {
|
||||||
if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
|
if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -854,10 +842,9 @@ DotImplementationStrategy GetDotImplementationStrategy(
|
|||||||
// Any Matrix-Vector product of floating point or integral type, or
|
// Any Matrix-Vector product of floating point or integral type, or
|
||||||
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
|
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
|
||||||
// IR implementation.
|
// IR implementation.
|
||||||
if ((dot_info.result_shape.dimensions_size() <= 1 ||
|
if (dot_info.result_shape.dimensions_size() == 2 &&
|
||||||
(dot_info.result_shape.dimensions_size() == 2 &&
|
(dot_info.result_shape.dimensions(0) == 1 ||
|
||||||
(dot_info.result_shape.dimensions(0) == 1 ||
|
dot_info.result_shape.dimensions(1) == 1) &&
|
||||||
dot_info.result_shape.dimensions(1) == 1))) &&
|
|
||||||
(primitive_util::IsFloatingPointType(element_type) ||
|
(primitive_util::IsFloatingPointType(element_type) ||
|
||||||
primitive_util::IsIntegralType(element_type))) {
|
primitive_util::IsIntegralType(element_type))) {
|
||||||
return DotImplementationStrategy::kTiledLlvmIrGemv;
|
return DotImplementationStrategy::kTiledLlvmIrGemv;
|
||||||
|
@ -3019,8 +3019,7 @@ class HloInstruction::FusionReusesParamElements {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
HloInstruction::UseKind HloInstruction::OperandElementUse(
|
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
|
||||||
int64 operand_num) const {
|
|
||||||
switch (opcode_) {
|
switch (opcode_) {
|
||||||
case HloOpcode::kBitcast:
|
case HloOpcode::kBitcast:
|
||||||
case HloOpcode::kConcatenate:
|
case HloOpcode::kConcatenate:
|
||||||
@ -3031,28 +3030,30 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(
|
|||||||
return UseKind::kUsePermutingElements;
|
return UseKind::kUsePermutingElements;
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
// Pad reuses the padding value but not the padded array elements.
|
// Pad reuses the padding value but not the padded array elements.
|
||||||
return operand_num > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
|
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
|
||||||
case HloOpcode::kReduce:
|
case HloOpcode::kReduce:
|
||||||
// Reduce reuses the init values but not the operand array elements.
|
// Reduce reuses the init values but not the operand array elements.
|
||||||
return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
|
return i >= Cast<HloReduceInstruction>(this)->input_count()
|
||||||
? UseKind::kReuse
|
? UseKind::kReuse
|
||||||
: UseKind::kUsePermutingElements;
|
: UseKind::kUsePermutingElements;
|
||||||
case HloOpcode::kFusion:
|
case HloOpcode::kFusion:
|
||||||
// Uses the memoizing, recursive computation defined above.
|
// Uses the memoizing, recursive computation defined above.
|
||||||
return FusionReusesParamElements::Compute(operand_num,
|
return FusionReusesParamElements::Compute(i, *fused_expression_root());
|
||||||
*fused_expression_root());
|
|
||||||
case HloOpcode::kDot:
|
case HloOpcode::kDot:
|
||||||
// Matrix-vector dots do not reuse the matrix operand.
|
// Dot operations with inputs [A,B] * [B,1] do not re-use
|
||||||
if (shape().dimensions_size() <= 1) {
|
// elements on their left operand.
|
||||||
if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
|
// Dot operations with inputs [1,A] * [A,B] do not re-use
|
||||||
(operand_num == 1 && operand(0)->shape().rank() <= 1)) {
|
// elements on their right operand.
|
||||||
|
if (shape().dimensions_size() == 2) {
|
||||||
|
if ((i == 0 && shape().dimensions(1) == 1) ||
|
||||||
|
(i == 1 && shape().dimensions(0) == 1)) {
|
||||||
return UseKind::kUse;
|
return UseKind::kUse;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return UseKind::kReuse;
|
return UseKind::kReuse;
|
||||||
case HloOpcode::kDynamicUpdateSlice:
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
// Dynamic-update-slice reuses only start_indices.
|
// Dynamic-update-slice reuses only start_indices.
|
||||||
if (operand_num == 0 || operand_num == 1) {
|
if (i == 0 || i == 1) {
|
||||||
return UseKind::kUse;
|
return UseKind::kUse;
|
||||||
}
|
}
|
||||||
return UseKind::kReuse;
|
return UseKind::kReuse;
|
||||||
|
@ -1768,8 +1768,8 @@ class HloInstruction {
|
|||||||
// Removes a user for this instruction.
|
// Removes a user for this instruction.
|
||||||
void RemoveUser(HloInstruction* user);
|
void RemoveUser(HloInstruction* user);
|
||||||
|
|
||||||
// Returns how this instruction uses elements of its operand at operand_num.
|
// Returns how this instruction uses elements of its `i`th operand.
|
||||||
UseKind OperandElementUse(int64 operand_num) const;
|
UseKind OperandElementUse(int64 i) const;
|
||||||
|
|
||||||
// Helper for implementing backend_config(). Parses backend_config_ into the
|
// Helper for implementing backend_config(). Parses backend_config_ into the
|
||||||
// given proto.
|
// given proto.
|
||||||
|
@ -374,11 +374,6 @@ std::vector<DotTestParam> CreateDotTestParameters() {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
|
|
||||||
add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
|
add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
|
||||||
add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
|
add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
|
||||||
add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
|
add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
|
||||||
|
@ -585,14 +585,13 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
|
|||||||
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
|
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
|
||||||
HloInstruction* lhs,
|
HloInstruction* lhs,
|
||||||
HloInstruction* rhs) {
|
HloInstruction* rhs) {
|
||||||
CHECK_LE(lhs->shape().rank(), 2);
|
CHECK_EQ(lhs->shape().rank(), 2);
|
||||||
CHECK_LE(rhs->shape().rank(), 2);
|
CHECK_EQ(rhs->shape().rank(), 2);
|
||||||
PrecisionConfig precision_config;
|
PrecisionConfig precision_config;
|
||||||
precision_config.mutable_operand_precision()->Resize(
|
precision_config.mutable_operand_precision()->Resize(
|
||||||
2, PrecisionConfig::DEFAULT);
|
2, PrecisionConfig::DEFAULT);
|
||||||
DotDimensionNumbers dot_dimension_numbers;
|
DotDimensionNumbers dot_dimension_numbers;
|
||||||
dot_dimension_numbers.add_lhs_contracting_dimensions(
|
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
|
||||||
lhs->shape().rank() > 1 ? 1 : 0);
|
|
||||||
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
|
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
|
||||||
return absl::make_unique<HloDotInstruction>(
|
return absl::make_unique<HloDotInstruction>(
|
||||||
shape, lhs, rhs, dot_dimension_numbers, precision_config);
|
shape, lhs, rhs, dot_dimension_numbers, precision_config);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user