Merge pull request #40128 from Intel-tensorflow:sriniva2/tp_batch_matmul

PiperOrigin-RevId: 315324346
Change-Id: Ieb6918ad7ef5f9ac59773a15fc8eca9fa8c8ef16
TensorFlower Gardener committed 2020-06-08 12:31:59 -07:00
commit 6ceeae8697
4 changed files with 62 additions and 96 deletions


@@ -394,10 +394,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.batch_matmul,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back({csinfo_.batch_matmul_v2,
mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
rinfo_.push_back(
{csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
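The hunk above swaps the rewrite predicate for BatchMatMul and BatchMatMulV2 from AlwaysRewrite to MatMulRewrite, so the layout pass only renames a node to its _Mkl counterpart when the node's dtype is one the narrowed MKL kernel still registers for. MatMulRewrite's body is not part of this diff; the following is only a sketch of such a type-gated predicate, using standard TensorFlow graph APIs (GetNodeAttr, DT_FLOAT, DT_BFLOAT16) under a hypothetical name.

// Sketch of a type-gated rewrite predicate: admit the node only when its "T"
// attribute is a dtype the MKL batch-matmul kernel supports after this change.
#include "tensorflow/core/framework/node_def_util.h"  // GetNodeAttr
#include "tensorflow/core/graph/graph.h"              // Node

static bool MatMulRewriteSketch(const tensorflow::Node* n) {
  tensorflow::DataType T;
  if (!tensorflow::GetNodeAttr(n->def(), "T", &T).ok()) return false;
  return T == tensorflow::DT_FLOAT || T == tensorflow::DT_BFLOAT16;
}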


@@ -25,7 +25,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
#if defined(INTEL_MKL)
#include <vector>
#include "mkl_cblas.h"
@@ -54,7 +54,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename Scalar, bool v2_bcast>
class BatchMatMulMkl : public OpKernel {
public:
explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) {
explicit BatchMatMulMkl(OpKernelConstruction* context)
: OpKernel(context), eigen_batch_mm_v2_(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
}
@@ -104,6 +105,14 @@ class BatchMatMulMkl : public OpKernel {
"In[0] and In[1] must have compatible batch dimensions: ",
lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));
#ifdef ENABLE_MKLDNN_THREADPOOL
if (bcast.IsBroadcastingRequired()) {
// Calling Eigen Kernel for broadcasting case and return. Eigen does
// not have BF16 support, so we have to fail graciously in that case.
eigen_batch_mm_v2_.Compute(ctx);
return;
}
#endif // ENABLE_MKLDNN_THREADPOOL
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
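The ENABLE_MKLDNN_THREADPOOL block added above bails out to the embedded Eigen kernel (eigen_batch_mm_v2_) whenever batch-dimension broadcasting is required, before any of the grouped-GEMM setup below runs. For reference, a minimal illustration of when MatMulBCast, the same utility the kernel already queries, reports broadcasting; the shapes are invented for the example.

// Batch dims [2,1] vs. [1,3] broadcast to [2,3]: IsBroadcastingRequired() is
// true, so a threadpool build would hand this case to the Eigen kernel.
#include "tensorflow/core/util/matmul_bcast.h"

void BroadcastFallbackExample() {
  tensorflow::MatMulBCast bcast({2, 1, 4, 5},   // lhs: batch dims [2,1], 4x5 matrices
                                {1, 3, 5, 6});  // rhs: batch dims [1,3], 5x6 matrices
  const bool needs_fallback = bcast.IsBroadcastingRequired();  // true
  const auto out_batches = bcast.output_batch_size();          // 2 * 3 = 6
  (void)needs_fallback;
  (void)out_batches;
}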
@@ -149,22 +158,27 @@ class BatchMatMulMkl : public OpKernel {
std::vector<MKL_INT> ldc_array(batch_size, N);
std::vector<MKL_INT> group_size(1, batch_size);
if (std::is_same<Scalar, bfloat16>::value) {
bool threadpool_enabled = false;
#ifdef ENABLE_MKLDNN_THREADPOOL
threadpool_enabled = true;
#endif // ENABLE_MKLDNN_THREADPOOL
if (std::is_same<Scalar, bfloat16>::value || threadpool_enabled) {
// DNNL bfloat16 API requires a, b, and c as pointers to tensors
// represented as flat-byte array.
const Scalar* a = nullptr;
const Scalar* b = nullptr;
OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
errors::Unimplemented("Broadcasting is not supported for "
"BFloat16 _MklBatchMatMul yet."));
Scalar* c = nullptr;
a = &lhs_reshaped(0, 0, 0);
b = &rhs_reshaped(0, 0, 0);
Scalar* c = &out_reshaped(0, 0, 0);
OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
errors::Unimplemented("Broadcasting is not supported for "
"_MklBatchMatMul yet."));
c = &out_reshaped(0, 0, 0);
// TODO(nhasabni): Use appropriate cast instead of passing addresses of
// a,b and c.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1,
group_size);
group_size, ctx);
} else {
std::vector<const Scalar*> a_array;
std::vector<const Scalar*> b_array;
@@ -196,86 +210,48 @@ class BatchMatMulMkl : public OpKernel {
// pointer is to 2D matrix.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a_array[0], lda_array, &b_array[0], ldb_array,
&c_array[0], ldc_array, 1, group_size);
&c_array[0], ldc_array, 1, group_size, ctx);
}
}
private:
bool adj_x_;
bool adj_y_;
BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;
template <typename T,
typename std::enable_if<(std::is_same<T, float>::value ||
std::is_same<T, double>::value),
int>::type = 0>
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const std::vector<MKL_INT>& M_Array,
const std::vector<MKL_INT>& N_Array,
const std::vector<MKL_INT>& K_Array, const T** A_Array,
const std::vector<MKL_INT>& lda_Array,
const T** B_Array,
const std::vector<MKL_INT>& ldb_Array, T** C_Array,
const std::vector<MKL_INT>& ldc_Array,
const MKL_INT group_count,
const std::vector<MKL_INT>& group_size) {
void MklCblasGemmBatch(
const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
const std::vector<MKL_INT>& K_Array, const float** A_Array,
const std::vector<MKL_INT>& lda_Array, const float** B_Array,
const std::vector<MKL_INT>& ldb_Array, float** C_Array,
const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
#ifndef ENABLE_MKLDNN_THREADPOOL
std::vector<CBLAS_TRANSPOSE> TransA_Array(
group_size[0], TransA ? CblasTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_Array(
group_size[0], TransB ? CblasTrans : CblasNoTrans);
if (std::is_same<T, float>::value) {
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
&N_Array[0], &K_Array[0], &alpha_Array[0],
reinterpret_cast<const float**>(A_Array), &lda_Array[0],
reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
&beta_Array[0], reinterpret_cast<float**>(C_Array),
&ldc_Array[0], group_count, &group_size[0]);
} else {
std::vector<double> alpha_Array(group_size[0], 1.0);
std::vector<double> beta_Array(group_size[0], 0.0);
cblas_dgemm_batch(
Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0], &N_Array[0],
&K_Array[0], &alpha_Array[0],
reinterpret_cast<const double**>(A_Array), &lda_Array[0],
reinterpret_cast<const double**>(B_Array), &ldb_Array[0],
&beta_Array[0], reinterpret_cast<double**>(C_Array), &ldc_Array[0],
group_count, &group_size[0]);
}
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
&N_Array[0], &K_Array[0], &alpha_Array[0],
reinterpret_cast<const float**>(A_Array), &lda_Array[0],
reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
&beta_Array[0], reinterpret_cast<float**>(C_Array),
&ldc_Array[0], group_count, &group_size[0]);
#else
DCHECK(Layout == CblasRowMajor);
std::vector<bool> TransA_Array(group_size[0], TransA);
std::vector<bool> TransB_Array(group_size[0], TransB);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
dnnl_gemm_batch<float>(TransA_Array, TransB_Array, M_Array, N_Array,
K_Array, alpha_Array, *A_Array, *B_Array, beta_Array,
*C_Array, group_count, group_size, ctx);
#endif // !ENABLE_MKLDNN_THREADPOOL
}
template <typename T,
typename std::enable_if<(std::is_same<T, complex64>::value ||
std::is_same<T, complex128>::value),
int>::type = 0>
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const std::vector<MKL_INT>& M_Array,
const std::vector<MKL_INT>& N_Array,
const std::vector<MKL_INT>& K_Array, const T** A_Array,
const std::vector<MKL_INT>& lda_Array,
const T** B_Array,
const std::vector<MKL_INT>& ldb_Array, T** C_Array,
const std::vector<MKL_INT>& ldc_Array,
const MKL_INT group_count,
const std::vector<MKL_INT>& group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
std::vector<T> alpha_Array(group_size[0], {1.0f, 0.0f});
std::vector<T> beta_Array(group_size[0], {0.0f, 0.0f});
auto gemm_fn = (std::is_same<T, complex64>::value) ? cblas_cgemm_batch
: cblas_zgemm_batch;
gemm_fn(Layout, &TransA_array[0], &TransB_array[0], &M_Array[0],
&N_Array[0], &K_Array[0], static_cast<const void*>(&alpha_Array[0]),
reinterpret_cast<const void**>(A_Array), &lda_Array[0],
reinterpret_cast<const void**>(B_Array), &ldb_Array[0],
static_cast<const void*>(&beta_Array[0]),
reinterpret_cast<void**>(C_Array), &ldc_Array[0], group_count,
&group_size[0]);
}
// BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards.
#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
void MklCblasGemmBatch(
const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
@@ -284,7 +260,7 @@ class BatchMatMulMkl : public OpKernel {
const std::vector<MKL_INT>& lda_Array, const bfloat16** B_Array,
const std::vector<MKL_INT>& ldb_Array, bfloat16** C_Array,
const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
const std::vector<MKL_INT>& group_size) {
const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
DCHECK(Layout == CblasRowMajor);
std::vector<bool> TransA_Array(group_size[0], TransA);
std::vector<bool> TransB_Array(group_size[0], TransB);
@@ -292,9 +268,9 @@ class BatchMatMulMkl : public OpKernel {
std::vector<float> beta_Array(group_size[0], 0.0);
// TODO(nhasabni): Remove *A when we pass a, b, and c correctly.
// MKLDNN API does not require lda, ldb, and ldc.
dnnl_gemm_batch<bfloat16>(TransA_Array, TransB_Array, M_Array, N_Array,
K_Array, alpha_Array, *A_Array, *B_Array,
beta_Array, *C_Array, group_count, group_size);
dnnl_gemm_batch<bfloat16>(
TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array,
*A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx);
}
#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
};
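Whichever backend MklCblasGemmBatch ends up calling (cblas_?gemm_batch without the threadpool, dnnl_gemm_batch with it), the single-group case used here computes the same thing: batch_size independent products C_i = op(A_i) * op(B_i) with alpha = 1 and beta = 0. A plain row-major reference loop, useful as a mental model or a test oracle but not the shipped implementation:

// Reference semantics of the single-group batched GEMM used by this kernel.
// A is batch x M x K, B is batch x K x N, C is batch x M x N (row-major);
// trans_a/trans_b mirror adj_x_/adj_y_, under which the stored slices are
// K x M and N x K respectively.
#include <cstdint>

template <typename Scalar>
void NaiveBatchGemm(bool trans_a, bool trans_b, int64_t batch, int64_t M,
                    int64_t N, int64_t K, const Scalar* A, const Scalar* B,
                    Scalar* C) {
  for (int64_t i = 0; i < batch; ++i) {
    const Scalar* a = A + i * M * K;  // i-th lhs slice
    const Scalar* b = B + i * K * N;  // i-th rhs slice
    Scalar* c = C + i * M * N;        // i-th output slice
    for (int64_t m = 0; m < M; ++m) {
      for (int64_t n = 0; n < N; ++n) {
        Scalar acc = Scalar(0);
        for (int64_t k = 0; k < K; ++k) {
          const Scalar av = trans_a ? a[k * M + m] : a[m * K + k];
          const Scalar bv = trans_b ? b[n * K + k] : b[k * N + n];
          acc += av * bv;
        }
        c[m * N + n] = acc;
      }
    }
  }
}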
@@ -315,13 +291,7 @@ class BatchMatMulMkl : public OpKernel {
#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL_V2);
#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
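The TF_CALL_* lines above instantiate the registration macros once per remaining dtype; the double and complex instantiations are dropped so the registered kernels match the narrowed "T" lists in the op definitions below. The macros themselves sit outside this hunk; the following is a hedged reconstruction of their shape (standard REGISTER_KERNEL_BUILDER usage with the MKL name-change label), not a verbatim quote of the file.

// Approximate form of the per-type registration macros expanded by TF_CALL_*:
// one labeled CPU kernel per dtype, with v2_bcast = false for _MklBatchMatMul
// and true for the V2 op.
#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMul")                             \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, false>)

#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2")                           \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, true>)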


@@ -707,8 +707,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
const std::vector<int>& n, const std::vector<int>& k,
const std::vector<float>& alpha, const T* a, const T* b,
const std::vector<float>& beta, T* c,
const int group_count,
const std::vector<int>& group_size) {
const int group_count, const std::vector<int>& group_size,
OpKernelContext* ctx = nullptr) {
// Current BatchMatMul support in Tensorflow is narrower than the one offered
// by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
// group of size equal to batch_size, and all MatMul parameters (m, n, k,
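The signature change above threads an OpKernelContext into dnnl_gemm_batch, defaulted to nullptr so existing call sites without a context keep compiling; the pointer only matters in ENABLE_MKLDNN_THREADPOOL builds, where it is the route by which the batched GEMM work reaches TensorFlow's intra-op thread pool. The snippet below is a hypothetical illustration of that idea, not the body of dnnl_gemm_batch; ParallelFor and tensorflow_cpu_worker_threads() are existing TensorFlow APIs, while the helper name and cost estimate are invented.

// Hypothetical sketch: run one GEMM per batch slice on the context's intra-op
// pool when a context is available, and serially when ctx is the defaulted
// nullptr.
#include <functional>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"

void ParallelOverBatches(tensorflow::OpKernelContext* ctx, int64_t batch,
                         const std::function<void(int64_t)>& one_gemm) {
  if (ctx == nullptr) {  // no context: behave like the old serial call sites
    for (int64_t i = 0; i < batch; ++i) one_gemm(i);
    return;
  }
  auto* workers = ctx->device()->tensorflow_cpu_worker_threads()->workers;
  const int64_t cost_per_gemm = 10000;  // rough per-slice cost estimate
  workers->ParallelFor(batch, cost_per_gemm,
                       [&](int64_t start, int64_t end) {
                         for (int64_t i = start; i < end; ++i) one_gemm(i);
                       });
}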


@@ -142,9 +142,7 @@ REGISTER_OP("_MklBatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
.Attr(
"T: {bfloat16, half, float, double, int32, int64, complex64, "
"complex128}")
.Attr("T: {bfloat16, float}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn(shape_inference::BatchMatMulShape);
@@ -153,9 +151,7 @@ REGISTER_OP("_MklBatchMatMulV2")
.Input("x: T")
.Input("y: T")
.Output("output: T")
.Attr(
"T: {bfloat16, half, float, double, int32, int64, complex64, "
"complex128}")
.Attr("T: {bfloat16, float}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn(shape_inference::BatchMatMulV2Shape);