Addressed review comments.

This commit is contained in:
mdfaijul 2020-10-09 09:41:17 -07:00
parent 819c4dc813
commit ce41ea7800

View File

@ -19,7 +19,7 @@ limitations under the License.
// Multiplication (MatMul) operations. We currently register this kernel only
// for oneDNN supported data types (float, bfloat16). The maximum number of
// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank
// exceeds 12, we fallback to Eigen library based kernel.
// exceeds 12, we fall back to Eigen library based kernel.
#define EIGEN_USE_THREADS
@ -119,11 +119,11 @@ class BatchMatMulMkl : public OpKernel {
out_shape.AddDim(lhs_rows);
out_shape.AddDim(rhs_cols);
// The maximum number of dimensions for a tensor in DNNL is 12.
OP_REQUIRES(ctx, out_shape.dims() <= 12,
errors::InvalidArgument(
"Rank of output tensor is required as <= 12, ", "but is ",
out_shape.dims(), ". Current implementation supports upto ",
"rank 12 tensors."));
OP_REQUIRES(
ctx, out_shape.dims() <= 12,
errors::InvalidArgument(
"Rank of output tensor must be <= 12, but is ", out_shape.dims(),
". Current implementation supports upto rank 12 tensors."));
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
@ -156,7 +156,7 @@ class BatchMatMulMkl : public OpKernel {
using dims = dnnl::memory::dims;
// This method makes the rank (ndims) of input same as the output by creating
// This method makes the rank (ndims) of input same as the output by adding
// new axes to the input. For example, if input shape is [a, b, c, d] and
// output shape is [e, f, g, h, i, j], then the reshaped input would have a
// shape of [1, 1, a, b, c, d].
@ -188,7 +188,7 @@ class BatchMatMulMkl : public OpKernel {
// Create dnnl::memory::dims for inputs and output of same rank.
// It is assumed here that MatMulBCast object creates output_batch_shape as
// a conforming superset of input batch shapes, i.e., ndims_out >=
// ndims_lhs and ndims_out >= ndims_lhs.
// ndims_lhs and ndims_out >= ndims_rhs.
if (ndims_lhs < ndims_out) {
ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
}