From 819c4dc8133f138738e0f366b3a2c490c268eb60 Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Fri, 2 Oct 2020 10:16:36 -0700 Subject: [PATCH 1/2] Enabled DNNL support for BatchMatMul with broadcast. --- .../core/kernels/mkl/mkl_batch_matmul_op.cc | 246 +++++++----------- .../core/kernels/mkl/mkl_matmul_ops_common.h | 67 ----- tensorflow/workspace.bzl | 8 +- third_party/mkl_dnn/mkldnn_v1.BUILD | 4 +- 4 files changed, 101 insertions(+), 224 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index da5a239c224..66903d8ff7a 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -15,27 +15,16 @@ limitations under the License. // See docs in ../ops/math_ops.cc. -// This file uses both oneDNN and MKL CBLAS batched xGEMM for acceleration of -// Batch Matrix-Matrix Multiplication (MatMul) operations. -// We currently register this kernel only for oneDNN supported data -// types (float, bfloat16). This file can be built with and without the use of -// the binary MKL CBLAS calls, controlled by the macro INTEL_MKL_DNN_ONLY. -// If INTEL_MKL_DNN_ONLY is defined, only oneDNN is used. For cases not -// supported by oneDNN (ex. Batchmatmul with broadcasting) we fall back to the -// default CPU implementation. -// if INTEL_MKL_DNN_ONLY is not defined, both oneDNN and MKL CBLAS -// implementations are used. This is only temporary, once we are able handle all -// cases with oneDNN, CBLAS calls will be removed. +// This file uses oneDNN library for acceleration of Batch Matrix-Matrix +// Multiplication (MatMul) operations. We currently register this kernel only +// for oneDNN supported data types (float, bfloat16). The maximum number of +// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank +// exceeds 12, we fallback to Eigen library based kernel. #define EIGEN_USE_THREADS #if defined(INTEL_MKL) -#include -#if !defined(INTEL_MKL_DNN_ONLY) -#include "mkl_cblas.h" -#endif // !INTEL_MKL_DNN_ONLY -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -50,6 +39,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -100,8 +90,8 @@ class BatchMatMulMkl : public OpKernel { } // lhs and rhs can have different dimensions - const int ndims_lhs = lhs.dims(); - const int ndims_rhs = rhs.dims(); + const auto ndims_lhs = lhs.dims(); + const auto ndims_rhs = rhs.dims(); // Get broadcast info MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); @@ -111,16 +101,7 @@ class BatchMatMulMkl : public OpKernel { "In[0] and In[1] must have compatible batch dimensions: ", lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString())); -#if defined(INTEL_MKL_DNN_ONLY) - if (bcast.IsBroadcastingRequired()) { - // Calling Eigen Kernel for broadcasting case and return. Eigen does - // not have BF16 support, so we have to fail graciously in that case. 
- eigen_batch_mm_v2_.Compute(ctx); - return; - } -#endif // INTEL_MKL_DNN_ONLY TensorShape out_shape = bcast.output_batch_shape(); - auto batch_size = bcast.output_batch_size(); auto lhs_rows = lhs.dim_size(ndims_lhs - 2); auto lhs_cols = lhs.dim_size(ndims_lhs - 1); @@ -137,6 +118,12 @@ class BatchMatMulMkl : public OpKernel { out_shape.AddDim(lhs_rows); out_shape.AddDim(rhs_cols); + // The maximum number of dimensions for a tensor in DNNL is 12. + OP_REQUIRES(ctx, out_shape.dims() <= 12, + errors::InvalidArgument( + "Rank of output tensor is required as <= 12, ", "but is ", + out_shape.dims(), ". Current implementation supports upto ", + "rank 12 tensors.")); Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); @@ -149,75 +136,17 @@ class BatchMatMulMkl : public OpKernel { return; } - auto rhs_reshaped = rhs.template flat_inner_dims(); - auto lhs_reshaped = lhs.template flat_inner_dims(); - auto out_reshaped = out->template flat_inner_dims(); - const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1); - const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2); - const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2); - - std::vector m_array(batch_size, M); - std::vector n_array(batch_size, N); - std::vector k_array(batch_size, K); - std::vector lda_array(batch_size, adj_x_ ? M : K); - std::vector ldb_array(batch_size, adj_y_ ? K : N); - std::vector ldc_array(batch_size, N); - std::vector group_size(1, batch_size); - - bool bcast_not_supported = false; -#if defined(INTEL_MKL_DNN_ONLY) - bcast_not_supported = true; -#endif // INTEL_MKL_DNN_ONLY - if (std::is_same::value || bcast_not_supported) { - // DNNL bfloat16 API requires a, b, and c as pointers to tensors - // represented as flat-byte array. - const Scalar* a = nullptr; - const Scalar* b = nullptr; - Scalar* c = nullptr; - a = &lhs_reshaped(0, 0, 0); - b = &rhs_reshaped(0, 0, 0); - OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(), - errors::Unimplemented("Broadcasting is not supported for " - "_MklBatchMatMul yet.")); - c = &out_reshaped(0, 0, 0); - // TODO(nhasabni): Use appropriate cast instead of passing addresses of - // a,b and c. - MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, - k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1, - group_size, ctx); - } else { - std::vector a_array; - std::vector b_array; - std::vector c_array; - a_array.reserve(batch_size); - b_array.reserve(batch_size); - c_array.reserve(batch_size); - - if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(i, 0, 0)); - b_array.push_back(&rhs_reshaped(i, 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); - } - } else { - // Broadcasting is needed, so get the mapping from flattened output - // batch indices to x's and y's flattened batch indices. - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0)); - b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); - } - } - - // MKL CBLAS API requires a, b, and c as array of pointers, where each - // pointer is to 2D matrix. - MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, - k_array, &a_array[0], lda_array, &b_array[0], ldb_array, - &c_array[0], ldc_array, 1, group_size, ctx); - } + // Compute parameters for DNNL matmul primitive. 
+    auto params = CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape);
+    // Create or retrieve matmul primitive from cache.
+    MklMatMulPrimitive<Scalar>* matmul_prim =
+        MklMatMulPrimitiveFactory<Scalar>::Get(
+            *params, false /* value for do_not_cache */);
+    // Execute matmul primitive.
+    std::shared_ptr<stream> cpu_stream;
+    cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+    matmul_prim->Execute(lhs.flat<Scalar>().data(), rhs.flat<Scalar>().data(),
+                         out->flat<Scalar>().data(), cpu_stream);
   }
 
  private:
@@ -225,60 +154,78 @@ class BatchMatMulMkl : public OpKernel {
   bool adj_y_;
   BatchMatMulV2Op<Device, Scalar> eigen_batch_mm_v2_;
 
-  void MklCblasGemmBatch(
-      const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
-      const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
-      const std::vector<MKL_INT>& K_Array, const float** A_Array,
-      const std::vector<MKL_INT>& lda_Array, const float** B_Array,
-      const std::vector<MKL_INT>& ldb_Array, float** C_Array,
-      const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
-      const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
-#if !defined(INTEL_MKL_DNN_ONLY)
-    std::vector<CBLAS_TRANSPOSE> TransA_Array(
-        group_size[0], TransA ? CblasTrans : CblasNoTrans);
-    std::vector<CBLAS_TRANSPOSE> TransB_Array(
-        group_size[0], TransB ? CblasTrans : CblasNoTrans);
-    std::vector<float> alpha_Array(group_size[0], 1.0);
-    std::vector<float> beta_Array(group_size[0], 0.0);
-    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
-                      &N_Array[0], &K_Array[0], &alpha_Array[0],
-                      reinterpret_cast<const float**>(A_Array), &lda_Array[0],
-                      reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
-                      &beta_Array[0], reinterpret_cast<float**>(C_Array),
-                      &ldc_Array[0], group_count, &group_size[0]);
-#else
-    DCHECK(Layout == CblasRowMajor);
-    std::vector<bool> TransA_Array(group_size[0], TransA);
-    std::vector<bool> TransB_Array(group_size[0], TransB);
-    std::vector<float> alpha_Array(group_size[0], 1.0);
-    std::vector<float> beta_Array(group_size[0], 0.0);
-    dnnl_gemm_batch(TransA_Array, TransB_Array, M_Array, N_Array,
-                    K_Array, alpha_Array, *A_Array, *B_Array, beta_Array,
-                    *C_Array, group_count, group_size, ctx);
-#endif  // !INTEL_MKL_DNN_ONLY
+  using dims = dnnl::memory::dims;
+
+  // This method makes the rank (ndims) of input same as the output by creating
+  // new axes to the input. For example, if input shape is [a, b, c, d] and
+  // output shape is [e, f, g, h, i, j], then the reshaped input would have a
+  // shape of [1, 1, a, b, c, d].
+  void ExpandInputDimsToOutputShape(const TensorShape& input_shape,
+                                    const TensorShape& output_shape,
+                                    dims* reshaped_dims) {
+    auto ndims_input = input_shape.dims();
+    auto ndims_output = output_shape.dims();
+    auto dim_offset = ndims_output - ndims_input;
+    DCHECK(dim_offset > 0);
+    reshaped_dims->clear();
+    reshaped_dims->resize(ndims_output, 1);
+    auto input_dims = input_shape.dim_sizes();
+    for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx)
+      reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx];
   }
-// BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards.
-#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) - void MklCblasGemmBatch( - const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, - const std::vector& M_Array, const std::vector& N_Array, - const std::vector& K_Array, const bfloat16** A_Array, - const std::vector& lda_Array, const bfloat16** B_Array, - const std::vector& ldb_Array, bfloat16** C_Array, - const std::vector& ldc_Array, const MKL_INT group_count, - const std::vector& group_size, OpKernelContext* ctx) { - DCHECK(Layout == CblasRowMajor); - std::vector TransA_Array(group_size[0], TransA); - std::vector TransB_Array(group_size[0], TransB); - std::vector alpha_Array(group_size[0], 1.0); - std::vector beta_Array(group_size[0], 0.0); - // TODO(nhasabni): Remove *A when we pass a, b, and c correctly. - // MKLDNN API does not require lda, ldb, and ldc. - dnnl_gemm_batch( - TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array, - *A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx); + + std::unique_ptr CreateMatMulParams( + const TensorShape& lhs_shape, const TensorShape& rhs_shape, + const TensorShape& out_shape) { + const auto ndims_lhs = lhs_shape.dims(); + const auto ndims_rhs = rhs_shape.dims(); + const auto ndims_out = out_shape.dims(); + auto lhs_dims = TFShapeToMklDnnDims(lhs_shape); + auto rhs_dims = TFShapeToMklDnnDims(rhs_shape); + auto out_dims = TFShapeToMklDnnDims(out_shape); + + // DNNL matmul_primitive requires ranks of inputs and output to be same. + // Create dnnl::memory::dims for inputs and output of same rank. + // It is assumed here that MatMulBCast object creates output_batch_shape as + // a conforming superset of input batch shapes, i.e., ndims_out >= + // ndims_lhs and ndims_out >= ndims_lhs. + if (ndims_lhs < ndims_out) { + ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims); + } + if (ndims_rhs < ndims_out) { + ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims); + } + + using dim = dnnl::memory::dim; + dim m; // number of rows in x + dim k; // number of columns in x + dim n; // number of columns in y + auto lhs_strides = CalculateTFStrides(lhs_dims); + auto rhs_strides = CalculateTFStrides(rhs_dims); + auto out_strides = CalculateTFStrides(out_dims); + + if (adj_x_) { + int m_idx = ndims_out - 1; + int k_idx = ndims_out - 2; + m = lhs_dims[m_idx]; + k = lhs_dims[k_idx]; + std::swap(lhs_dims[m_idx], lhs_dims[k_idx]); + lhs_strides[m_idx] = m; + lhs_strides[k_idx] = 1; + } + + if (adj_y_) { + int k_idx = ndims_out - 1; + int n_idx = ndims_out - 2; + k = rhs_dims[k_idx]; + n = rhs_dims[n_idx]; + std::swap(rhs_dims[k_idx], rhs_dims[n_idx]); + rhs_strides[k_idx] = k; + rhs_strides[n_idx] = 1; + } + return std::make_unique( + lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides, out_strides); } -#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 }; #define REGISTER_BATCH_MATMUL_MKL(TYPE) \ @@ -294,14 +241,11 @@ class BatchMatMulMkl : public OpKernel { .TypeConstraint("T") \ .Label(mkl_op_registry::kMklNameChangeOpLabel), \ BatchMatMulMkl) - #ifdef ENABLE_MKL TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2); -#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL); TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2); -#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_MKL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h 
index e084b25f737..b77d033c9de 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -35,12 +35,6 @@ using mkldnn::stream; namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifdef INTEL_MKL_DNN_ONLY -// Temporarily copying some definitions from mkl_cblas.h so the same code can -// be used when calling oneDNN or CBLAS batchmatmul in mkl_batch_matmul_op.cc. -typedef enum { CblasRowMajor, CblasColumnMajor } CBLAS_LAYOUT; -#define MKL_INT int -#endif // This structure aggregates multiple inputs to MklDnnMatMul* methods. struct MklDnnMatMulFwdParams { @@ -729,67 +723,6 @@ class MklMatMulPrimitiveFactory : public MklPrimitiveFactory { } }; -template -void dnnl_gemm_batch(const std::vector& transa, - const std::vector& transb, const std::vector& m, - const std::vector& n, const std::vector& k, - const std::vector& alpha, const T* a, const T* b, - const std::vector& beta, T* c, - const int group_count, const std::vector& group_size, - OpKernelContext* ctx = nullptr) { - // Current BatchMatMul support in Tensorflow is narrower than the one offered - // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1 - // group of size equal to batch_size, and all MatMul parameters (m, n, k, - // alpha, beta) within that group are same. - DCHECK(group_size.size() == 1); - DCHECK(transa.size() == group_size[0]); - DCHECK(transb.size() == group_size[0]); - DCHECK(alpha.size() == group_size[0]); - DCHECK(beta.size() == group_size[0]); - DCHECK(m.size() == group_size[0]); - DCHECK(n.size() == group_size[0]); - DCHECK(k.size() == group_size[0]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(transa[0] == transa[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(transb[0] == transb[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(alpha[0] == alpha[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(beta[0] == beta[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(m[0] == m[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(n[0] == n[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(k[0] == k[idx]); - - using dims = mkldnn::memory::dims; - // Prepare strides based on the transa and transb flags: transposed - // matrices have strides swapped BatchMatMul in MKL-DNN supports 3D metrices - // so far. That is why strides are 3D also. - dims a_sizes = dims{group_size[0], m[0], k[0]}; - dims b_sizes = dims{group_size[0], k[0], n[0]}; - dims c_sizes = dims{group_size[0], m[0], n[0]}; - dims a_strides = - !transa[0] ? dims{m[0] * k[0], k[0], 1} : dims{k[0] * m[0], 1, m[0]}; - dims b_strides = - !transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]}; - dims c_strides = dims{m[0] * n[0], n[0], 1}; - - // MklMatMul uses const alpha and beta, make guarantee here to ensure - // they are never changed. - DCHECK_EQ(alpha, 1.0f); - DCHECK_EQ(beta, 0.f); - - MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides, - c_strides); - MklMatMulPrimitive* matmul_prim = - MklMatMulPrimitiveFactory::Get(params, 0); - - // Execute matmul primitive. 
- std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); - matmul_prim->Execute(a, b, c, cpu_stream); -} - template void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, const T* a, int64_t lda, const T* b, int64_t ldb, diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index cee2c29a0b0..9560353a3ae 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -209,11 +209,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "mkl_dnn_v1", build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"), - sha256 = "aef4d2a726f76f5b98902491a1a4ac69954039aa8e5a1d67ef6ce58ed00e23a6", - strip_prefix = "oneDNN-1.5.1", + sha256 = "5369f7b2f0b52b40890da50c0632c3a5d1082d98325d0f2bff125d19d0dcaa1d", + strip_prefix = "oneDNN-1.6.4", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz", - "https://github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz", + "https://github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz", ], ) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 0e6acc2fadd..32a3fa7351b 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -58,8 +58,8 @@ template_rule( out = "include/dnnl_version.h", substitutions = { "@DNNL_VERSION_MAJOR@": "1", - "@DNNL_VERSION_MINOR@": "5", - "@DNNL_VERSION_PATCH@": "1", + "@DNNL_VERSION_MINOR@": "6", + "@DNNL_VERSION_PATCH@": "4", "@DNNL_VERSION_HASH@": "N/A", }, ) From ce41ea78005d839d25b2b4a7da03b40f3545352a Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Fri, 9 Oct 2020 09:41:17 -0700 Subject: [PATCH 2/2] Addressed review comments. --- .../core/kernels/mkl/mkl_batch_matmul_op.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index 66903d8ff7a..c56aa73b7ce 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -19,7 +19,7 @@ limitations under the License. // Multiplication (MatMul) operations. We currently register this kernel only // for oneDNN supported data types (float, bfloat16). The maximum number of // dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank -// exceeds 12, we fallback to Eigen library based kernel. +// exceeds 12, we fall back to Eigen library based kernel. #define EIGEN_USE_THREADS @@ -119,11 +119,11 @@ class BatchMatMulMkl : public OpKernel { out_shape.AddDim(lhs_rows); out_shape.AddDim(rhs_cols); // The maximum number of dimensions for a tensor in DNNL is 12. - OP_REQUIRES(ctx, out_shape.dims() <= 12, - errors::InvalidArgument( - "Rank of output tensor is required as <= 12, ", "but is ", - out_shape.dims(), ". Current implementation supports upto ", - "rank 12 tensors.")); + OP_REQUIRES( + ctx, out_shape.dims() <= 12, + errors::InvalidArgument( + "Rank of output tensor must be <= 12, but is ", out_shape.dims(), + ". 
Current implementation supports up to rank 12 tensors."));
     Tensor* out = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
@@ -156,7 +156,7 @@ class BatchMatMulMkl : public OpKernel {
 
   using dims = dnnl::memory::dims;
 
-  // This method makes the rank (ndims) of input same as the output by creating
+  // This method makes the rank (ndims) of input same as the output by adding
   // new axes to the input. For example, if input shape is [a, b, c, d] and
   // output shape is [e, f, g, h, i, j], then the reshaped input would have a
   // shape of [1, 1, a, b, c, d].
@@ -188,7 +188,7 @@ class BatchMatMulMkl : public OpKernel {
     // Create dnnl::memory::dims for inputs and output of same rank.
     // It is assumed here that MatMulBCast object creates output_batch_shape as
    // a conforming superset of input batch shapes, i.e., ndims_out >=
-    // ndims_lhs and ndims_out >= ndims_lhs.
+    // ndims_lhs and ndims_out >= ndims_rhs.
     if (ndims_lhs < ndims_out) {
       ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
     }
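
Note (illustrative, not part of the patch): the broadcast behavior the new kernel relies on comes from the oneDNN matmul primitive itself, which takes src/weights/dst descriptors of equal rank and broadcasts any batch dimension of size 1. Below is a minimal standalone sketch against the oneDNN 1.x C++ API (the version range this patch pins); shapes, variable names, and build setup are assumptions made only for illustration.

#include <dnnl.hpp>
#include <vector>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // lhs: [2, 3, 4], rhs: [1, 4, 5] -> out: [2, 3, 5].
  // The rhs batch dimension is 1, so oneDNN broadcasts it across the two
  // lhs batches -- the case that previously forced a fallback.
  memory::dims lhs_dims = {2, 3, 4};
  memory::dims rhs_dims = {1, 4, 5};
  memory::dims out_dims = {2, 3, 5};

  memory::desc lhs_md(lhs_dims, memory::data_type::f32, memory::format_tag::abc);
  memory::desc rhs_md(rhs_dims, memory::data_type::f32, memory::format_tag::abc);
  memory::desc out_md(out_dims, memory::data_type::f32, memory::format_tag::abc);

  std::vector<float> lhs_data(2 * 3 * 4, 1.0f);
  std::vector<float> rhs_data(1 * 4 * 5, 1.0f);
  std::vector<float> out_data(2 * 3 * 5, 0.0f);

  memory lhs_mem(lhs_md, eng, lhs_data.data());
  memory rhs_mem(rhs_md, eng, rhs_data.data());
  memory out_mem(out_md, eng, out_data.data());

  // matmul::desc is the oneDNN v1.x way to describe the operation.
  matmul::desc desc(lhs_md, rhs_md, out_md);
  matmul::primitive_desc pd(desc, eng);
  matmul(pd).execute(s, {{DNNL_ARG_SRC, lhs_mem},
                         {DNNL_ARG_WEIGHTS, rhs_mem},
                         {DNNL_ARG_DST, out_mem}});
  s.wait();
  // out_data now holds the two 3x5 products (every entry is 4.0f here).
  return 0;
}

This mirrors what CreateMatMulParams sets up inside the kernel: both inputs are described with the same rank as the output, and size-1 batch dimensions express the broadcast.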