diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index 43539ac908f..88e6e623979 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -52,20 +52,15 @@ typedef Eigen::SyclDevice SYCLDevice; namespace { +// Returns the pair of dimensions along which to perform Tensor contraction to +// emulate matrix multiplication. +// For matrix multiplication of 2D Tensors X and Y, X is contracted along +// second dimension and Y is contracted along the first dimension (if neither X +// nor Y is adjointed). The dimension to contract along is switched when any +// operand is adjointed. +// See http://en.wikipedia.org/wiki/Tensor_contraction Eigen::IndexPair ContractionDims(bool adj_x, bool adj_y) { - if (!adj_x) { - if (!adj_y) { - return Eigen::IndexPair(1, 0); - } else { - return Eigen::IndexPair(1, 1); - } - } else { - if (!adj_y) { - return Eigen::IndexPair(0, 0); - } else { - return Eigen::IndexPair(0, 1); - } - } + return Eigen::IndexPair(adj_x ? 0 : 1, adj_y ? 1 : 0); } // Parallel batch matmul kernel based on the multi-threaded tensor contraction