Addressed review comments.
This commit is contained in:
parent
819c4dc813
commit
ce41ea7800
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
// Multiplication (MatMul) operations. We currently register this kernel only
|
// Multiplication (MatMul) operations. We currently register this kernel only
|
||||||
// for oneDNN supported data types (float, bfloat16). The maximum number of
|
// for oneDNN supported data types (float, bfloat16). The maximum number of
|
||||||
// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank
|
// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank
|
||||||
// exceeds 12, we fallback to Eigen library based kernel.
|
// exceeds 12, we fall back to Eigen library based kernel.
|
||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
@ -119,11 +119,11 @@ class BatchMatMulMkl : public OpKernel {
|
|||||||
out_shape.AddDim(lhs_rows);
|
out_shape.AddDim(lhs_rows);
|
||||||
out_shape.AddDim(rhs_cols);
|
out_shape.AddDim(rhs_cols);
|
||||||
// The maximum number of dimensions for a tensor in DNNL is 12.
|
// The maximum number of dimensions for a tensor in DNNL is 12.
|
||||||
OP_REQUIRES(ctx, out_shape.dims() <= 12,
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument(
|
ctx, out_shape.dims() <= 12,
|
||||||
"Rank of output tensor is required as <= 12, ", "but is ",
|
errors::InvalidArgument(
|
||||||
out_shape.dims(), ". Current implementation supports upto ",
|
"Rank of output tensor must be <= 12, but is ", out_shape.dims(),
|
||||||
"rank 12 tensors."));
|
". Current implementation supports upto rank 12 tensors."));
|
||||||
|
|
||||||
Tensor* out = nullptr;
|
Tensor* out = nullptr;
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
||||||
@ -156,7 +156,7 @@ class BatchMatMulMkl : public OpKernel {
|
|||||||
|
|
||||||
using dims = dnnl::memory::dims;
|
using dims = dnnl::memory::dims;
|
||||||
|
|
||||||
// This method makes the rank (ndims) of input same as the output by creating
|
// This method makes the rank (ndims) of input same as the output by adding
|
||||||
// new axes to the input. For example, if input shape is [a, b, c, d] and
|
// new axes to the input. For example, if input shape is [a, b, c, d] and
|
||||||
// output shape is [e, f, g, h, i, j], then the reshaped input would have a
|
// output shape is [e, f, g, h, i, j], then the reshaped input would have a
|
||||||
// shape of [1, 1, a, b, c, d].
|
// shape of [1, 1, a, b, c, d].
|
||||||
@ -188,7 +188,7 @@ class BatchMatMulMkl : public OpKernel {
|
|||||||
// Create dnnl::memory::dims for inputs and output of same rank.
|
// Create dnnl::memory::dims for inputs and output of same rank.
|
||||||
// It is assumed here that MatMulBCast object creates output_batch_shape as
|
// It is assumed here that MatMulBCast object creates output_batch_shape as
|
||||||
// a conforming superset of input batch shapes, i.e., ndims_out >=
|
// a conforming superset of input batch shapes, i.e., ndims_out >=
|
||||||
// ndims_lhs and ndims_out >= ndims_lhs.
|
// ndims_lhs and ndims_out >= ndims_rhs.
|
||||||
if (ndims_lhs < ndims_out) {
|
if (ndims_lhs < ndims_out) {
|
||||||
ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
|
ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user