Polish some comments in dot decomposer.

PiperOrigin-RevId: 316544755
Change-Id: I46ce48dcbf64119e4795b923f5b45b814e8bb8c7
This commit is contained in:
A. Unique TensorFlower 2020-06-15 14:26:51 -07:00 committed by TensorFlower Gardener
parent 813bd968d0
commit 92bd07fee5

View File

@ -29,10 +29,12 @@ namespace xla {
namespace {
// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. This requires transposing and reshapes of the lhs and rhs and
// reshaping the output batch to the original shape.
// Convert a dot into a canonical form;
// * Non-contracting dimensions are reshaped together,
// * Contracting dimensions are reshaped together,
// * Batch dimensions are the most major dimensions.
// This requires transposing and reshaping of the lhs and rhs, and reshaping the
// output batch to the original shape.
Status CanonicalizeDot(HloInstruction* original_dot) {
auto computation = original_dot->parent();
const auto& original_dnums = original_dot->dot_dimension_numbers();
@ -63,7 +65,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
}
}
// The canonical form of the lhs is
// [BatchDims, NonContractingDims, ContractingsDims]
// [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
// If NonContractingDimsProduct is 1, it is omitted.
std::vector<int64> lhs_transpose;
lhs_transpose.reserve(lhs_rank);
lhs_transpose.insert(lhs_transpose.end(),
@ -109,7 +112,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
}
// The canonical form of the rhs is
// [BatchDims, ContractingsDims, NonContractingDims]
// [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
// If NonContractingDimsProduct is 1, it is omitted.
std::vector<int64> rhs_transpose;
rhs_transpose.reserve(rhs_rank);
rhs_transpose.insert(rhs_transpose.end(),