Merge pull request #43741 from Intel-tensorflow:amin/dnnl-batchmatmul

PiperOrigin-RevId: 336737523
Change-Id: I6a0525beb2618b1ac59adc1d486da1524639497c
TensorFlower Gardener 2020-10-12 14:11:23 -07:00
commit 30429491d7
4 changed files with 100 additions and 223 deletions


@@ -15,26 +15,16 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
// This file uses both oneDNN and MKL CBLAS batched xGEMM for acceleration of
// Batch Matrix-Matrix Multiplication (MatMul) operations.
// We currently register this kernel only for oneDNN supported data
// types (float, bfloat16). This file can be built with and without the use of
// the binary MKL CBLAS calls, controlled by the macro INTEL_MKL_DNN_ONLY.
// If INTEL_MKL_DNN_ONLY is defined, only oneDNN is used. For cases not
// supported by oneDNN (e.g., BatchMatMul with broadcasting), we fall back to
// the default CPU implementation.
// If INTEL_MKL_DNN_ONLY is not defined, both the oneDNN and MKL CBLAS
// implementations are used. This is only temporary; once we are able to handle
// all cases with oneDNN, the CBLAS calls will be removed.
// This file uses oneDNN library for acceleration of Batch Matrix-Matrix
// Multiplication (MatMul) operations. We currently register this kernel only
// for oneDNN-supported data types (float, bfloat16). The maximum number of
// dimensions (rank) oneDNN supports for the output tensor is 12. If the output
// tensor rank exceeds 12, we fall back to the Eigen-based kernel.
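The output shape that this rank limit applies to is the NumPy-style broadcast of the two inputs' batch dimensions, followed by the two matrix dimensions. The sketch below is a standalone illustration of that computation under the assumption that both inputs have rank >= 2; it is not TensorFlow's MatMulBCast, just a minimal equivalent for clarity.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BatchMatMulOutputShape(const std::vector<int64_t>& lhs,
                                            const std::vector<int64_t>& rhs,
                                            bool adj_x, bool adj_y) {
  // Innermost two dims are the matrix dims; adjoint flags swap their roles.
  const int64_t m = adj_x ? lhs[lhs.size() - 1] : lhs[lhs.size() - 2];
  const int64_t n = adj_y ? rhs[rhs.size() - 2] : rhs[rhs.size() - 1];
  std::vector<int64_t> lhs_batch(lhs.begin(), lhs.end() - 2);
  std::vector<int64_t> rhs_batch(rhs.begin(), rhs.end() - 2);

  // Broadcast the batch dims right-aligned; a size-1 dim broadcasts freely.
  const size_t batch_rank = std::max(lhs_batch.size(), rhs_batch.size());
  std::vector<int64_t> out(batch_rank, 1);
  for (size_t i = 0; i < batch_rank; ++i) {
    int64_t a = i < lhs_batch.size() ? lhs_batch[lhs_batch.size() - 1 - i] : 1;
    int64_t b = i < rhs_batch.size() ? rhs_batch[rhs_batch.size() - 1 - i] : 1;
    if (a != b && a != 1 && b != 1)
      throw std::invalid_argument("incompatible batch dimensions");
    out[batch_rank - 1 - i] = std::max(a, b);
  }
  out.push_back(m);  // rows of the result
  out.push_back(n);  // cols of the result
  return out;        // out.size() is the rank that must stay <= 12
}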
#define EIGEN_USE_THREADS
#if defined(INTEL_MKL)
#include <vector>
#if !defined(INTEL_MKL_DNN_ONLY)
#include "mkl_cblas.h"
#endif // !INTEL_MKL_DNN_ONLY
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -100,8 +90,8 @@ class BatchMatMulMkl : public OpKernel {
}
// lhs and rhs can have different dimensions
const int ndims_lhs = lhs.dims();
const int ndims_rhs = rhs.dims();
const auto ndims_lhs = lhs.dims();
const auto ndims_rhs = rhs.dims();
// Get broadcast info
MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
@@ -111,16 +101,7 @@ class BatchMatMulMkl : public OpKernel {
"In[0] and In[1] must have compatible batch dimensions: ",
lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));
#if defined(INTEL_MKL_DNN_ONLY)
if (bcast.IsBroadcastingRequired()) {
// Call the Eigen kernel for the broadcasting case and return. Eigen does
// not have BF16 support, so we have to fail gracefully in that case.
eigen_batch_mm_v2_.Compute(ctx);
return;
}
#endif // INTEL_MKL_DNN_ONLY
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
@@ -137,6 +118,12 @@ class BatchMatMulMkl : public OpKernel {
out_shape.AddDim(lhs_rows);
out_shape.AddDim(rhs_cols);
// The maximum number of dimensions for a tensor in DNNL is 12.
OP_REQUIRES(
ctx, out_shape.dims() <= 12,
errors::InvalidArgument(
"Rank of output tensor must be <= 12, but is ", out_shape.dims(),
". Current implementation supports upto rank 12 tensors."));
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
@@ -149,75 +136,17 @@ class BatchMatMulMkl : public OpKernel {
return;
}
auto rhs_reshaped = rhs.template flat_inner_dims<Scalar, 3>();
auto lhs_reshaped = lhs.template flat_inner_dims<Scalar, 3>();
auto out_reshaped = out->template flat_inner_dims<Scalar, 3>();
const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);
std::vector<MKL_INT> m_array(batch_size, M);
std::vector<MKL_INT> n_array(batch_size, N);
std::vector<MKL_INT> k_array(batch_size, K);
std::vector<MKL_INT> lda_array(batch_size, adj_x_ ? M : K);
std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
std::vector<MKL_INT> ldc_array(batch_size, N);
std::vector<MKL_INT> group_size(1, batch_size);
bool bcast_not_supported = false;
#if defined(INTEL_MKL_DNN_ONLY)
bcast_not_supported = true;
#endif // INTEL_MKL_DNN_ONLY
if (std::is_same<Scalar, bfloat16>::value || bcast_not_supported) {
// The DNNL bfloat16 API requires a, b, and c as pointers to tensors
// represented as flat byte arrays.
const Scalar* a = nullptr;
const Scalar* b = nullptr;
Scalar* c = nullptr;
a = &lhs_reshaped(0, 0, 0);
b = &rhs_reshaped(0, 0, 0);
OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
errors::Unimplemented("Broadcasting is not supported for "
"_MklBatchMatMul yet."));
c = &out_reshaped(0, 0, 0);
// TODO(nhasabni): Use appropriate cast instead of passing addresses of
// a,b and c.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1,
group_size, ctx);
} else {
std::vector<const Scalar*> a_array;
std::vector<const Scalar*> b_array;
std::vector<Scalar*> c_array;
a_array.reserve(batch_size);
b_array.reserve(batch_size);
c_array.reserve(batch_size);
if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(i, 0, 0));
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
} else {
// Broadcasting is needed, so get the mapping from flattened output
// batch indices to x's and y's flattened batch indices.
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0));
b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
}
// The MKL CBLAS API requires a, b, and c as arrays of pointers, where each
// pointer points to a 2D matrix.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a_array[0], lda_array, &b_array[0], ldb_array,
&c_array[0], ldc_array, 1, group_size, ctx);
}
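The x_batch_indices / y_batch_indices used above map each flattened output batch index back to a flattened input batch index under broadcasting. A standalone sketch of that mapping, assuming the input batch shape has already been padded with leading 1s to the output's rank; this is an illustrative helper, not the MatMulBCast API.

#include <cstdint>
#include <vector>

// Maps a flattened output batch index to the corresponding flattened input
// batch index, clamping broadcast (size-1) input dimensions to coordinate 0.
int64_t MapBatchIndex(int64_t out_index,
                      const std::vector<int64_t>& out_batch_shape,
                      const std::vector<int64_t>& in_batch_shape) {
  int64_t in_index = 0;
  int64_t out_stride = 1, in_stride = 1;
  for (int64_t d = static_cast<int64_t>(out_batch_shape.size()) - 1; d >= 0;
       --d) {
    const int64_t coord = (out_index / out_stride) % out_batch_shape[d];
    if (in_batch_shape[d] != 1) in_index += coord * in_stride;
    out_stride *= out_batch_shape[d];
    in_stride *= in_batch_shape[d];
  }
  return in_index;
}
// Example: output batch shape {2, 3} against lhs batch shape {1, 3}:
// MapBatchIndex(4, {2, 3}, {1, 3}) == 1, i.e. output batch (1, 1) reads
// lhs batch (0, 1).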
// Compute parameters for DNNL matmul primitive.
auto params = CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape);
// Create or retrieve matmul primitive from cache.
MklMatMulPrimitive<Scalar>* matmul_prim =
MklMatMulPrimitiveFactory<Scalar>::Get(
*params, false /* value for do_not_cache */);
// Execute matmul primitive.
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
matmul_prim->Execute(lhs.flat<Scalar>().data(), rhs.flat<Scalar>().data(),
out->flat<Scalar>().data(), cpu_stream);
}
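MklMatMulPrimitiveFactory::Get above returns a cached primitive keyed by the matmul parameters, so the relatively expensive oneDNN primitive is built once per distinct shape and reused on later calls. Below is a simplified sketch of that caching pattern; Params and Primitive are hypothetical placeholders, not the actual MklMatMulParams / MklMatMulPrimitive classes, and a real cache would also need per-thread or locked access.

#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Placeholder key type: serialized shapes stand in for the full parameter set.
struct Params {
  std::vector<int64_t> a_dims, b_dims, c_dims;
  std::string Key() const {
    std::ostringstream os;
    for (int64_t d : a_dims) os << d << 'x';
    os << ';';
    for (int64_t d : b_dims) os << d << 'x';
    os << ';';
    for (int64_t d : c_dims) os << d << 'x';
    return os.str();
  }
};

// Placeholder primitive: construction is the expensive step being cached.
struct Primitive {
  explicit Primitive(const Params&) { /* build the dnnl::matmul here */ }
  void Execute(const float*, const float*, float*) { /* run the primitive */ }
};

// Build the primitive once per distinct key and reuse it afterwards.
Primitive* GetOrCreatePrimitive(const Params& params) {
  static std::map<std::string, std::unique_ptr<Primitive>> cache;
  std::unique_ptr<Primitive>& slot = cache[params.Key()];
  if (slot == nullptr) slot.reset(new Primitive(params));
  return slot.get();
}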
private:
@@ -225,60 +154,78 @@ class BatchMatMulMkl : public OpKernel {
bool adj_y_;
BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;
void MklCblasGemmBatch(
const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
const std::vector<MKL_INT>& K_Array, const float** A_Array,
const std::vector<MKL_INT>& lda_Array, const float** B_Array,
const std::vector<MKL_INT>& ldb_Array, float** C_Array,
const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
#if !defined(INTEL_MKL_DNN_ONLY)
std::vector<CBLAS_TRANSPOSE> TransA_Array(
group_size[0], TransA ? CblasTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_Array(
group_size[0], TransB ? CblasTrans : CblasNoTrans);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
&N_Array[0], &K_Array[0], &alpha_Array[0],
reinterpret_cast<const float**>(A_Array), &lda_Array[0],
reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
&beta_Array[0], reinterpret_cast<float**>(C_Array),
&ldc_Array[0], group_count, &group_size[0]);
#else
DCHECK(Layout == CblasRowMajor);
std::vector<bool> TransA_Array(group_size[0], TransA);
std::vector<bool> TransB_Array(group_size[0], TransB);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
dnnl_gemm_batch<float>(TransA_Array, TransB_Array, M_Array, N_Array,
K_Array, alpha_Array, *A_Array, *B_Array, beta_Array,
*C_Array, group_count, group_size, ctx);
#endif // !INTEL_MKL_DNN_ONLY
using dims = dnnl::memory::dims;
// This method makes the rank (ndims) of the input the same as that of the
// output by prepending new axes to the input. For example, if the input shape
// is [a, b, c, d] and the output shape is [e, f, g, h, i, j], then the
// reshaped input would have a shape of [1, 1, a, b, c, d].
void ExpandInputDimsToOutputShape(const TensorShape& input_shape,
const TensorShape& output_shape,
dims* reshaped_dims) {
auto ndims_input = input_shape.dims();
auto ndims_output = output_shape.dims();
auto dim_offset = ndims_output - ndims_input;
DCHECK(dim_offset > 0);
reshaped_dims->clear();
reshaped_dims->resize(ndims_output, 1);
auto input_dims = input_shape.dim_sizes();
for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx)
reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx];
}
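For illustration, the same left-padding can be shown on plain vectors (dnnl::memory::dims is essentially a vector of 64-bit dims). The standalone sketch below mirrors the method above under the same assumption that the output rank is larger than the input rank; ExpandToRank is a hypothetical name.

#include <cstdint>
#include <vector>

// ExpandToRank({3, 4}, 4) == {1, 1, 3, 4}: missing leading axes become 1s.
std::vector<int64_t> ExpandToRank(const std::vector<int64_t>& input,
                                  int out_rank) {
  std::vector<int64_t> expanded(out_rank, 1);
  const int offset = out_rank - static_cast<int>(input.size());
  for (size_t i = 0; i < input.size(); ++i) expanded[i + offset] = input[i];
  return expanded;
}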
// BatchMatMul BFloat16 support exists only in DNNL 1.2 and later.
#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
void MklCblasGemmBatch(
const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
const std::vector<MKL_INT>& K_Array, const bfloat16** A_Array,
const std::vector<MKL_INT>& lda_Array, const bfloat16** B_Array,
const std::vector<MKL_INT>& ldb_Array, bfloat16** C_Array,
const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
DCHECK(Layout == CblasRowMajor);
std::vector<bool> TransA_Array(group_size[0], TransA);
std::vector<bool> TransB_Array(group_size[0], TransB);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
// TODO(nhasabni): Remove *A when we pass a, b, and c correctly.
// MKLDNN API does not require lda, ldb, and ldc.
dnnl_gemm_batch<bfloat16>(
TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array,
*A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx);
std::unique_ptr<MklMatMulParams> CreateMatMulParams(
const TensorShape& lhs_shape, const TensorShape& rhs_shape,
const TensorShape& out_shape) {
const auto ndims_lhs = lhs_shape.dims();
const auto ndims_rhs = rhs_shape.dims();
const auto ndims_out = out_shape.dims();
auto lhs_dims = TFShapeToMklDnnDims(lhs_shape);
auto rhs_dims = TFShapeToMklDnnDims(rhs_shape);
auto out_dims = TFShapeToMklDnnDims(out_shape);
// The DNNL matmul_primitive requires the ranks of the inputs and the output
// to be the same, so create dnnl::memory::dims of equal rank for all three.
// It is assumed here that the MatMulBCast object creates output_batch_shape
// as a conforming superset of the input batch shapes, i.e., ndims_out >=
// ndims_lhs and ndims_out >= ndims_rhs.
if (ndims_lhs < ndims_out) {
ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims);
}
if (ndims_rhs < ndims_out) {
ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims);
}
using dim = dnnl::memory::dim;
dim m; // number of rows in x
dim k; // number of columns in x
dim n; // number of columns in y
auto lhs_strides = CalculateTFStrides(lhs_dims);
auto rhs_strides = CalculateTFStrides(rhs_dims);
auto out_strides = CalculateTFStrides(out_dims);
if (adj_x_) {
int m_idx = ndims_out - 1;
int k_idx = ndims_out - 2;
m = lhs_dims[m_idx];
k = lhs_dims[k_idx];
std::swap(lhs_dims[m_idx], lhs_dims[k_idx]);
lhs_strides[m_idx] = m;
lhs_strides[k_idx] = 1;
}
if (adj_y_) {
int k_idx = ndims_out - 1;
int n_idx = ndims_out - 2;
k = rhs_dims[k_idx];
n = rhs_dims[n_idx];
std::swap(rhs_dims[k_idx], rhs_dims[n_idx]);
rhs_strides[k_idx] = k;
rhs_strides[n_idx] = 1;
}
return std::make_unique<MklMatMulParams>(
lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides, out_strides);
}
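The adjoint handling above never copies data: it swaps the two innermost dims and rewrites their strides so that the row-major buffer is read as its transpose. The standalone sketch below illustrates that strided view on a 2x3 buffer; it is not TensorFlow code, just a demonstration of the dims/strides trick.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t K = 2, M = 3;
  // Physical layout: row-major [K, M] = [[1, 2, 3], [4, 5, 6]].
  std::vector<float> buf = {1, 2, 3, 4, 5, 6};

  // Logical view after the swap: dims {M, K}, strides {1, M}.
  const int64_t dims[2] = {M, K};
  const int64_t strides[2] = {1, M};

  for (int64_t i = 0; i < dims[0]; ++i) {    // logical rows (M)
    for (int64_t j = 0; j < dims[1]; ++j)    // logical cols (K)
      std::cout << buf[i * strides[0] + j * strides[1]] << " ";
    std::cout << "\n";                       // prints the transpose:
  }                                          // 1 4 / 2 5 / 3 6
  return 0;
}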
#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
};
#define REGISTER_BATCH_MATMUL_MKL(TYPE) \
@@ -294,14 +241,11 @@ class BatchMatMulMkl : public OpKernel {
.TypeConstraint<TYPE>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
BatchMatMulMkl<CPUDevice, TYPE, true>)
#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
#endif // ENABLE_MKL
} // end namespace tensorflow


@@ -35,12 +35,6 @@ using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef INTEL_MKL_DNN_ONLY
// Temporarily copying some definitions from mkl_cblas.h so the same code can
// be used when calling oneDNN or CBLAS batchmatmul in mkl_batch_matmul_op.cc.
typedef enum { CblasRowMajor, CblasColumnMajor } CBLAS_LAYOUT;
#define MKL_INT int
#endif
// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
@@ -729,67 +723,6 @@ class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
}
};
template <typename T>
void dnnl_gemm_batch(const std::vector<bool>& transa,
const std::vector<bool>& transb, const std::vector<int>& m,
const std::vector<int>& n, const std::vector<int>& k,
const std::vector<float>& alpha, const T* a, const T* b,
const std::vector<float>& beta, T* c,
const int group_count, const std::vector<int>& group_size,
OpKernelContext* ctx = nullptr) {
// Current BatchMatMul support in TensorFlow is narrower than what MKL and
// MKL-DNN offer: TensorFlow uses only one group, whose size equals
// batch_size, and all MatMul parameters (m, n, k, alpha, beta) within that
// group are the same.
DCHECK(group_size.size() == 1);
DCHECK(transa.size() == group_size[0]);
DCHECK(transb.size() == group_size[0]);
DCHECK(alpha.size() == group_size[0]);
DCHECK(beta.size() == group_size[0]);
DCHECK(m.size() == group_size[0]);
DCHECK(n.size() == group_size[0]);
DCHECK(k.size() == group_size[0]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
DCHECK(transa[0] == transa[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
DCHECK(transb[0] == transb[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
DCHECK(alpha[0] == alpha[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
DCHECK(beta[0] == beta[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(m[0] == m[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(n[0] == n[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(k[0] == k[idx]);
using dims = mkldnn::memory::dims;
// Prepare strides based on the transa and transb flags: transposed
// matrices have their strides swapped. BatchMatMul in MKL-DNN supports only
// 3D matrices so far, which is why the strides are 3D as well.
dims a_sizes = dims{group_size[0], m[0], k[0]};
dims b_sizes = dims{group_size[0], k[0], n[0]};
dims c_sizes = dims{group_size[0], m[0], n[0]};
dims a_strides =
!transa[0] ? dims{m[0] * k[0], k[0], 1} : dims{k[0] * m[0], 1, m[0]};
dims b_strides =
!transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]};
dims c_strides = dims{m[0] * n[0], n[0], 1};
// MklMatMul uses constant alpha and beta; check here to ensure they are
// never changed.
DCHECK_EQ(alpha, 1.0f);
DCHECK_EQ(beta, 0.f);
MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides,
c_strides);
MklMatMulPrimitive<T>* matmul_prim =
MklMatMulPrimitiveFactory<T>::Get(params, 0);
// Execute matmul primitive.
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
matmul_prim->Execute(a, b, c, cpu_stream);
}
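For reference, the batched-GEMM semantics this helper (removed in this change) implemented can be written as a naive loop: one group of `batch` independent row-major GEMMs sharing m, n, k, alpha, and beta, with a transpose applied per operand when the corresponding flag is set. The sketch below is a plain C++ reference, not part of TensorFlow.

#include <cstdint>

void ReferenceBatchGemm(bool transa, bool transb, int64_t batch, int64_t m,
                        int64_t n, int64_t k, float alpha, const float* a,
                        const float* b, float beta, float* c) {
  for (int64_t bi = 0; bi < batch; ++bi) {
    // Each batch element is a contiguous row-major matrix.
    const float* A = a + bi * m * k;  // [m, k], or [k, m] if transa
    const float* B = b + bi * k * n;  // [k, n], or [n, k] if transb
    float* C = c + bi * m * n;        // [m, n]
    for (int64_t i = 0; i < m; ++i)
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int64_t p = 0; p < k; ++p) {
          const float av = transa ? A[p * m + i] : A[i * k + p];
          const float bv = transb ? B[j * k + p] : B[p * n + j];
          acc += av * bv;
        }
        C[i * n + j] = alpha * acc + beta * C[i * n + j];
      }
  }
}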
template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,


@@ -209,11 +209,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn_v1",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"),
sha256 = "aef4d2a726f76f5b98902491a1a4ac69954039aa8e5a1d67ef6ce58ed00e23a6",
strip_prefix = "oneDNN-1.5.1",
sha256 = "5369f7b2f0b52b40890da50c0632c3a5d1082d98325d0f2bff125d19d0dcaa1d",
strip_prefix = "oneDNN-1.6.4",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz",
"https://github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz",
"https://github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz",
],
)


@@ -58,8 +58,8 @@ template_rule(
out = "include/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "1",
"@DNNL_VERSION_MINOR@": "5",
"@DNNL_VERSION_PATCH@": "1",
"@DNNL_VERSION_MINOR@": "6",
"@DNNL_VERSION_PATCH@": "4",
"@DNNL_VERSION_HASH@": "N/A",
},
)
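The generated dnnl_version.h exposes these values as the DNNL_VERSION_MAJOR, DNNL_VERSION_MINOR, and DNNL_VERSION_PATCH macros, so C++ code built against it can assert the expected oneDNN version at compile time. A small illustrative check follows; it assumes the generated header is on the include path and is not something TensorFlow itself does in this change.

#include <cstdio>
#include "dnnl_version.h"  // generated by the template_rule above

#if DNNL_VERSION_MAJOR < 1 || (DNNL_VERSION_MAJOR == 1 && DNNL_VERSION_MINOR < 6)
#error "expected oneDNN >= 1.6"
#endif

int main() {
  std::printf("built against oneDNN %d.%d.%d\n", DNNL_VERSION_MAJOR,
              DNNL_VERSION_MINOR, DNNL_VERSION_PATCH);
  return 0;
}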