Minor cleanup: compute contraction dimensions in a single line instead of four

in the batch matmul op kernel.

PiperOrigin-RevId: 237890589
This commit is contained in:
Anudhyan Boral 2019-03-11 15:03:02 -07:00 committed by TensorFlower Gardener
parent df3a9e555f
commit 73cdb00c26

View File

@ -52,20 +52,15 @@ typedef Eigen::SyclDevice SYCLDevice;
namespace { namespace {
// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along
// second dimension and Y is contracted along the first dimension (if neither X
// nor Y is adjointed). The dimension to contract along is switched when any
// operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) { Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
if (!adj_x) { return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
if (!adj_y) {
return Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
} else {
return Eigen::IndexPair<Eigen::DenseIndex>(1, 1);
}
} else {
if (!adj_y) {
return Eigen::IndexPair<Eigen::DenseIndex>(0, 0);
} else {
return Eigen::IndexPair<Eigen::DenseIndex>(0, 1);
}
}
} }
// Parallel batch matmul kernel based on the multi-threaded tensor contraction // Parallel batch matmul kernel based on the multi-threaded tensor contraction