Minor cleanup: compute contraction dimensions in a single line instead of four
in the batch matmul op kernel. PiperOrigin-RevId: 237890589
This commit is contained in:
parent
df3a9e555f
commit
73cdb00c26
@ -52,20 +52,15 @@ typedef Eigen::SyclDevice SYCLDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the pair of dimensions along which to perform Tensor contraction to
|
||||
// emulate matrix multiplication.
|
||||
// For matrix multiplication of 2D Tensors X and Y, X is contracted along
|
||||
// second dimension and Y is contracted along the first dimension (if neither X
|
||||
// nor Y is adjointed). The dimension to contract along is switched when any
|
||||
// operand is adjointed.
|
||||
// See http://en.wikipedia.org/wiki/Tensor_contraction
|
||||
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
|
||||
if (!adj_x) {
|
||||
if (!adj_y) {
|
||||
return Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
|
||||
} else {
|
||||
return Eigen::IndexPair<Eigen::DenseIndex>(1, 1);
|
||||
}
|
||||
} else {
|
||||
if (!adj_y) {
|
||||
return Eigen::IndexPair<Eigen::DenseIndex>(0, 0);
|
||||
} else {
|
||||
return Eigen::IndexPair<Eigen::DenseIndex>(0, 1);
|
||||
}
|
||||
}
|
||||
return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
|
||||
}
|
||||
|
||||
// Parallel batch matmul kernel based on the multi-threaded tensor contraction
|
||||
|
Loading…
Reference in New Issue
Block a user