Addressed review comments.

This commit is contained in:
mdfaijul 2020-10-09 09:41:17 -07:00
parent 819c4dc813
commit ce41ea7800

View File

@@ -19,7 +19,7 @@ limitations under the License.
 // Multiplication (MatMul) operations. We currently register this kernel only
 // for oneDNN supported data types (float, bfloat16). The maximum number of
 // dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank
-// exceeds 12, we fallback to Eigen library based kernel.
+// exceeds 12, we fall back to Eigen library based kernel.
 
 #define EIGEN_USE_THREADS
@@ -119,11 +119,11 @@ class BatchMatMulMkl : public OpKernel {
     out_shape.AddDim(lhs_rows);
     out_shape.AddDim(rhs_cols);
     // The maximum number of dimensions for a tensor in DNNL is 12.
-    OP_REQUIRES(ctx, out_shape.dims() <= 12,
-                errors::InvalidArgument(
-                    "Rank of output tensor is required as <= 12, ", "but is ",
-                    out_shape.dims(), ". Current implementation supports upto ",
-                    "rank 12 tensors."));
+    OP_REQUIRES(
+        ctx, out_shape.dims() <= 12,
+        errors::InvalidArgument(
+            "Rank of output tensor must be <= 12, but is ", out_shape.dims(),
+            ". Current implementation supports upto rank 12 tensors."));
     Tensor* out = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
@ -156,7 +156,7 @@ class BatchMatMulMkl : public OpKernel {
using dims = dnnl::memory::dims; using dims = dnnl::memory::dims;
// This method makes the rank (ndims) of input same as the output by creating // This method makes the rank (ndims) of input same as the output by adding
// new axes to the input. For example, if input shape is [a, b, c, d] and // new axes to the input. For example, if input shape is [a, b, c, d] and
// output shape is [e, f, g, h, i, j], then the reshaped input would have a // output shape is [e, f, g, h, i, j], then the reshaped input would have a
// shape of [1, 1, a, b, c, d]. // shape of [1, 1, a, b, c, d].
@@ -188,7 +188,7 @@ class BatchMatMulMkl : public OpKernel {
     // Create dnnl::memory::dims for inputs and output of same rank.
     // It is assumed here that MatMulBCast object creates output_batch_shape as
     // a conforming superset of input batch shapes, i.e., ndims_out >=
-    // ndims_lhs and ndims_out >= ndims_lhs.
+    // ndims_lhs and ndims_out >= ndims_rhs.
     if (ndims_lhs < ndims_out) {
       ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
     }