Addressed review comments.
This commit is contained in:
parent
819c4dc813
commit
ce41ea7800
@ -19,7 +19,7 @@ limitations under the License.
|
||||
// Multiplication (MatMul) operations. We currently register this kernel only
|
||||
// for oneDNN supported data types (float, bfloat16). The maximum number of
|
||||
// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank
|
||||
// exceeds 12, we fallback to Eigen library based kernel.
|
||||
// exceeds 12, we fall back to Eigen library based kernel.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
@ -119,11 +119,11 @@ class BatchMatMulMkl : public OpKernel {
|
||||
out_shape.AddDim(lhs_rows);
|
||||
out_shape.AddDim(rhs_cols);
|
||||
// The maximum number of dimensions for a tensor in DNNL is 12.
|
||||
OP_REQUIRES(ctx, out_shape.dims() <= 12,
|
||||
errors::InvalidArgument(
|
||||
"Rank of output tensor is required as <= 12, ", "but is ",
|
||||
out_shape.dims(), ". Current implementation supports upto ",
|
||||
"rank 12 tensors."));
|
||||
OP_REQUIRES(
|
||||
ctx, out_shape.dims() <= 12,
|
||||
errors::InvalidArgument(
|
||||
"Rank of output tensor must be <= 12, but is ", out_shape.dims(),
|
||||
". Current implementation supports upto rank 12 tensors."));
|
||||
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
||||
@ -156,7 +156,7 @@ class BatchMatMulMkl : public OpKernel {
|
||||
|
||||
using dims = dnnl::memory::dims;
|
||||
|
||||
// This method makes the rank (ndims) of input same as the output by creating
|
||||
// This method makes the rank (ndims) of input same as the output by adding
|
||||
// new axes to the input. For example, if input shape is [a, b, c, d] and
|
||||
// output shape is [e, f, g, h, i, j], then the reshaped input would have a
|
||||
// shape of [1, 1, a, b, c, d].
|
||||
@ -188,7 +188,7 @@ class BatchMatMulMkl : public OpKernel {
|
||||
// Create dnnl::memory::dims for inputs and output of same rank.
|
||||
// It is assumed here that MatMulBCast object creates output_batch_shape as
|
||||
// a conforming superset of input batch shapes, i.e., ndims_out >=
|
||||
// ndims_lhs and ndims_out >= ndims_lhs.
|
||||
// ndims_lhs and ndims_out >= ndims_rhs.
|
||||
if (ndims_lhs < ndims_out) {
|
||||
ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user