Automated rollback of commit 76e256261ba4feee18aeb9a2e69d3d5c861ad976

PiperOrigin-RevId: 253207758
This commit is contained in:
A. Unique TensorFlower 2019-06-14 04:51:35 -07:00 committed by TensorFlower Gardener
parent 76e256261b
commit f45c45d906
9 changed files with 73 additions and 106 deletions

View File

@ -1783,13 +1783,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
// If the lhs or rhs have only batch and contracting dimensions, a dot can be
// rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
if (options_.enable_dot_strength_reduction() &&
((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
lhs->shape().rank()) ||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
rhs->shape().rank()))) {
if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
lhs->shape().rank()) ||
(dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
rhs->shape().rank())) {
TF_ASSIGN_OR_RETURN(
HloInstruction * new_lhs,
NormalizeDotOperandToBatchMajorAndContractingMinor(
@ -1885,10 +1884,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
return ReplaceInstruction(dot, dot_of_gather_optimized);
}
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
RemoveDegenerateDimensionFromDot(dot));
if (removed_degenerate_dimensions) {
return Status::OK();
if (options_.enable_dot_strength_reduction()) {
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
RemoveDegenerateDimensionFromDot(dot));
if (removed_degenerate_dimensions) {
return Status::OK();
}
}
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).

View File

@ -22,6 +22,11 @@ namespace cpu {
namespace {
int64 BytesInDimension(const Shape& shape, int64 dimension) {
return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
shape.dimensions(dimension);
}
bool CanBeLoopFused(const HloInstruction& hlo) {
// These are the only ones we fuse since we rely on effective elemental IR
// generation.
@ -42,7 +47,8 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) {
const Shape& hlo_shape = hlo->shape();
return !ShapeUtil::ElementIsComplex(hlo_shape) &&
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1;
hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 &&
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
}
bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
@ -130,21 +136,18 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
// fusion can easily be overshadowed by the overhead of a naive GEMM
// algorithm in the IR.
const Shape& output_shape = consumer->shape();
if (output_shape.dimensions_size() <= 1) {
// We fuse in cases where we have a matrix*vector or vector*matrix dot and
if (output_shape.dimensions_size() == 2) {
// We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B]) and
// fusion can get rid of the larger tensor. We assume that a naive
// traversal of a small enough (to fit in L1) column or row tensor is
// "good enough" from the perspective of cache management; and calling out
// to an optimized GEMM kernel is not a huge win.
if (consumer->operand(0)->shape().rank() == 1 && operand_index == 1 &&
ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) <
kFusionThresholdBytes) {
if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
VLOG(2) << "Fusing small matrix-vector product.";
return true;
} else if (consumer->operand(1)->shape().rank() == 1 &&
operand_index == 0 &&
ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) <
kFusionThresholdBytes) {
} else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
VLOG(2) << "Fusing small matrix-vector product.";
return true;
}

View File

@ -51,12 +51,12 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
HloInstruction* dot = builder.AddInstruction(
MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1));
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
@ -68,14 +68,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {256}), "arg0"));
0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
HloInstruction* dot = builder.AddInstruction(
MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1));
MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
@ -89,14 +89,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
HloInstruction* dot = builder.AddInstruction(
MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1));
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
@ -110,7 +110,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {256}), "arg1"));
1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
@ -118,7 +118,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
HloInstruction* dot = builder.AddInstruction(
MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1));
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
@ -130,14 +130,14 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0"));
0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
HloInstruction* dot = builder.AddInstruction(
MakeDot(ShapeUtil::MakeShape(F32, {32 * 1024}), arg0, exp1));
MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
auto module = CreateNewUnverifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
@ -637,13 +637,6 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
if (m == 1) {
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
dot_shape = ShapeUtil::MakeShape(F32, {n});
} else if (n == 1) {
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
dot_shape = ShapeUtil::MakeShape(F32, {m});
}
auto* dot_lhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));

View File

@ -63,9 +63,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
auto builder = HloComputation::Builder(TestName());
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
auto dot_lhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
auto dot_rhs = builder.AddInstruction(
@ -83,12 +83,12 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
dot_lhs->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
dot_rhs->shape().layout()));
EXPECT_TRUE(
LayoutUtil::Equal(LayoutUtil::MakeLayout({0}), result->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
result->shape().layout()));
for (const auto& instruction : computation->instructions()) {
EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
}
@ -98,9 +98,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
// Two dot products have the same constant as the RHS, and both those dot
// products can be optimized if the constant has a column-major layout.
auto builder = HloComputation::Builder(TestName());
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0});
Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {24}, {0});
Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
auto dot_a_lhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, lhs_shape, "param0"));
auto dot_b_lhs = builder.AddInstruction(
@ -128,7 +128,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
dot_rhs->shape().layout()));
for (HloInstruction* instruction :
{dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0}),
EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
instruction->shape().layout()));
}
for (const auto& instruction : computation->instructions()) {
@ -270,13 +270,6 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
if (m == 1) {
dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
dot_shape = ShapeUtil::MakeShape(F32, {n});
} else if (n == 1) {
dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
dot_shape = ShapeUtil::MakeShape(F32, {m});
}
HloInstruction* dot_lhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
@ -345,21 +338,16 @@ static void AssertCorrectLayoutForDotOutputFusion(
Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
? LayoutUtil::MakeLayout({0, 1})
: LayoutUtil::MakeLayout({1, 0});
if (layout_assignment_result.dot_rhs_fusion_param->shape().rank() == 1) {
expected_dot_rhs_layout = LayoutUtil::MakeLayout({0});
}
EXPECT_TRUE(LayoutUtil::Equal(
expected_dot_rhs_layout,
layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
LayoutUtil::MakeDescendingLayout(
layout_assignment_result.dot_lhs_fusion_param->shape().rank()),
LayoutUtil::MakeLayout({1, 0}),
layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
LayoutUtil::MakeDescendingLayout(
layout_assignment_result.addend_fusion_param->shape().rank()),
LayoutUtil::MakeLayout({1, 0}),
layout_assignment_result.addend_fusion_param->shape().layout()));
EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
}

View File

@ -702,32 +702,19 @@ Status DotOpEmitter::EmitCallToRuntime() {
}
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
CHECK_LE(dot_info_.result_shape.dimensions_size(), 2);
CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2);
const Shape& lhs_shape = lhs_array_.GetShape();
const Shape& rhs_shape = rhs_array_.GetShape();
const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;
auto is_column_major = [](const Shape& shape) {
return shape.rank() > 1 && LayoutUtil::Minor(shape.layout(), 0) == 0;
};
// Non-contracting dots should never make it here.
CHECK_GE(dim_nums.lhs_contracting_dimensions_size(), 0);
CHECK_GE(dim_nums.rhs_contracting_dimensions_size(), 0);
return {
/*m=*/lhs_shape.rank() <= 1
? 1LL
: lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
/*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
/*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
/*n=*/rhs_shape.rank() <= 1
? 1LL
: rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
/*lhs_column_major=*/is_column_major(lhs_shape),
/*lhs_canonical=*/lhs_shape.rank() <= 1 ||
dim_nums.lhs_contracting_dimensions(0) == 1,
/*rhs_column_major=*/is_column_major(rhs_shape),
/*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)),
/*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
/*lhs_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 1,
/*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
/*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}
@ -735,7 +722,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
// column major.
absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() <= 1) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
hlo.shape().dimensions(0) == 1) {
if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) {
return 1;
}
@ -854,10 +842,9 @@ DotImplementationStrategy GetDotImplementationStrategy(
// Any Matrix-Vector product of floating point or integral type, or
// a transpose-dot fusion of the same can be lowered to a tiled LLVM
// IR implementation.
if ((dot_info.result_shape.dimensions_size() <= 1 ||
(dot_info.result_shape.dimensions_size() == 2 &&
(dot_info.result_shape.dimensions(0) == 1 ||
dot_info.result_shape.dimensions(1) == 1))) &&
if (dot_info.result_shape.dimensions_size() == 2 &&
(dot_info.result_shape.dimensions(0) == 1 ||
dot_info.result_shape.dimensions(1) == 1) &&
(primitive_util::IsFloatingPointType(element_type) ||
primitive_util::IsIntegralType(element_type))) {
return DotImplementationStrategy::kTiledLlvmIrGemv;

View File

@ -3019,8 +3019,7 @@ class HloInstruction::FusionReusesParamElements {
}
};
HloInstruction::UseKind HloInstruction::OperandElementUse(
int64 operand_num) const {
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
switch (opcode_) {
case HloOpcode::kBitcast:
case HloOpcode::kConcatenate:
@ -3031,28 +3030,30 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(
return UseKind::kUsePermutingElements;
case HloOpcode::kPad:
// Pad reuses the padding value but not the padded array elements.
return operand_num > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
case HloOpcode::kReduce:
// Reduce reuses the init values but not the operand array elements.
return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
return i >= Cast<HloReduceInstruction>(this)->input_count()
? UseKind::kReuse
: UseKind::kUsePermutingElements;
case HloOpcode::kFusion:
// Uses the memoizing, recursive computation defined above.
return FusionReusesParamElements::Compute(operand_num,
*fused_expression_root());
return FusionReusesParamElements::Compute(i, *fused_expression_root());
case HloOpcode::kDot:
// Matrix-vector dots do not reuse the matrix operand.
if (shape().dimensions_size() <= 1) {
if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
(operand_num == 1 && operand(0)->shape().rank() <= 1)) {
// Dot operations with inputs [A,B] * [B,1] do not re-use
// elements on their left operand.
// Dot operations with inputs [1,A] * [A,B] do not re-use
// elements on their right operand.
if (shape().dimensions_size() == 2) {
if ((i == 0 && shape().dimensions(1) == 1) ||
(i == 1 && shape().dimensions(0) == 1)) {
return UseKind::kUse;
}
}
return UseKind::kReuse;
case HloOpcode::kDynamicUpdateSlice:
// Dynamic-update-slice reuses only start_indices.
if (operand_num == 0 || operand_num == 1) {
if (i == 0 || i == 1) {
return UseKind::kUse;
}
return UseKind::kReuse;

View File

@ -1768,8 +1768,8 @@ class HloInstruction {
// Removes a user for this instruction.
void RemoveUser(HloInstruction* user);
// Returns how this instruction uses elements of its operand at operand_num.
UseKind OperandElementUse(int64 operand_num) const;
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
// Helper for implementing backend_config(). Parses backend_config_ into the
// given proto.

View File

@ -374,11 +374,6 @@ std::vector<DotTestParam> CreateDotTestParameters() {
}
};
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);

View File

@ -585,14 +585,13 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
HloInstruction* lhs,
HloInstruction* rhs) {
CHECK_LE(lhs->shape().rank(), 2);
CHECK_LE(rhs->shape().rank(), 2);
CHECK_EQ(lhs->shape().rank(), 2);
CHECK_EQ(rhs->shape().rank(), 2);
PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
2, PrecisionConfig::DEFAULT);
DotDimensionNumbers dot_dimension_numbers;
dot_dimension_numbers.add_lhs_contracting_dimensions(
lhs->shape().rank() > 1 ? 1 : 0);
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
return absl::make_unique<HloDotInstruction>(
shape, lhs, rhs, dot_dimension_numbers, precision_config);