Merge pull request #40128 from Intel-tensorflow:sriniva2/tp_batch_matmul
PiperOrigin-RevId: 315324346
Change-Id: Ieb6918ad7ef5f9ac59773a15fc8eca9fa8c8ef16
commit 6ceeae8697
@@ -394,10 +394,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.batch_matmul,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
-                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
     rinfo_.push_back({csinfo_.batch_matmul_v2,
                       mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
-                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+                      CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
     rinfo_.push_back(
         {csinfo_.concat, mkl_op_registry::GetMklOpName(csinfo_.concat),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
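Note: the functional change in this hunk is the rewrite predicate. Both batch-matmul entries move from AlwaysRewrite to MatMulRewrite, so the layout pass only rewrites nodes whose dtype the narrowed _MklBatchMatMul registration (see the REGISTER_OP hunks below) can serve. A minimal sketch of such a dtype-gated predicate, assuming the pass's usual Node/GetNodeAttr helpers (illustrative, not the verbatim source):

    // Sketch of a dtype-gated rewrite predicate in the style of
    // MklLayoutRewritePass (illustrative only).
    static bool MatMulRewrite(const Node* n) {
      DataType T;
      TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
      // _MklBatchMatMul{,V2} are now registered only for float and bfloat16,
      // so decline the rewrite for every other dtype.
      return T == DT_FLOAT || T == DT_BFLOAT16;
    }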
@@ -25,7 +25,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY)
+#if defined(INTEL_MKL)
 #include <vector>
 
 #include "mkl_cblas.h"
@@ -54,7 +54,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 template <typename Device, typename Scalar, bool v2_bcast>
 class BatchMatMulMkl : public OpKernel {
  public:
-  explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) {
+  explicit BatchMatMulMkl(OpKernelConstruction* context)
+      : OpKernel(context), eigen_batch_mm_v2_(context) {
     OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
     OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
   }
@@ -104,6 +105,14 @@ class BatchMatMulMkl : public OpKernel {
                     "In[0] and In[1] must have compatible batch dimensions: ",
                     lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));
 
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    if (bcast.IsBroadcastingRequired()) {
+      // Calling Eigen Kernel for broadcasting case and return. Eigen does
+      // not have BF16 support, so we have to fail graciously in that case.
+      eigen_batch_mm_v2_.Compute(ctx);
+      return;
+    }
+#endif  // ENABLE_MKLDNN_THREADPOOL
     TensorShape out_shape = bcast.output_batch_shape();
     auto batch_size = bcast.output_batch_size();
 
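Together with the constructor change above, the kernel now keeps a fully constructed Eigen kernel as a member and forwards the whole OpKernelContext to it whenever the threadpool build hits a broadcasting batch shape. A condensed sketch of that wiring (CanUseMklPath is a hypothetical stand-in for the "#ifdef ENABLE_MKLDNN_THREADPOOL plus bcast" check; the other names follow the diff):

    // Condensed sketch of the delegate-on-unsupported-case pattern
    // introduced here (illustrative only).
    template <typename Device, typename Scalar, bool v2_bcast>
    class BatchMatMulMklSketch : public OpKernel {
     public:
      explicit BatchMatMulMklSketch(OpKernelConstruction* context)
          : OpKernel(context), eigen_batch_mm_v2_(context) {}

      void Compute(OpKernelContext* ctx) override {
        if (!CanUseMklPath(ctx)) {
          // Delegate verbatim: the Eigen kernel allocates its own output
          // and implements broadcasting itself.
          eigen_batch_mm_v2_.Compute(ctx);
          return;
        }
        // ... MKL/DNNL batched-GEMM path ...
      }

     private:
      bool CanUseMklPath(OpKernelContext* ctx);  // hypothetical predicate
      BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;
    };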
@@ -149,22 +158,27 @@ class BatchMatMulMkl : public OpKernel {
     std::vector<MKL_INT> ldc_array(batch_size, N);
     std::vector<MKL_INT> group_size(1, batch_size);
 
-    if (std::is_same<Scalar, bfloat16>::value) {
+    bool threadpool_enabled = false;
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    threadpool_enabled = true;
+#endif  // ENABLE_MKLDNN_THREADPOOL
+    if (std::is_same<Scalar, bfloat16>::value || threadpool_enabled) {
       // DNNL bfloat16 API requires a, b, and c as pointers to tensors
       // represented as flat-byte array.
       const Scalar* a = nullptr;
       const Scalar* b = nullptr;
-      OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
-                  errors::Unimplemented("Broadcasting is not supported for "
-                                        "BFloat16 _MklBatchMatMul yet."));
+      Scalar* c = nullptr;
       a = &lhs_reshaped(0, 0, 0);
       b = &rhs_reshaped(0, 0, 0);
-      Scalar* c = &out_reshaped(0, 0, 0);
+      OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
+                  errors::Unimplemented("Broadcasting is not supported for "
+                                        "_MklBatchMatMul yet."));
+      c = &out_reshaped(0, 0, 0);
       // TODO(nhasabni): Use appropriate cast instead of passing addresses of
       // a,b and c.
       MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
                         k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1,
-                        group_size);
+                        group_size, ctx);
     } else {
       std::vector<const Scalar*> a_array;
       std::vector<const Scalar*> b_array;
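Why the call above passes &a, &b, &c with a group count of 1: MklCblasGemmBatch takes arrays of per-matrix pointers (const Scalar** / Scalar**), but on this path each operand is a single contiguous [batch, rows, cols] block, so one base pointer stands in for the whole batch and the DNNL callee strides through it. A sketch of the pointer arithmetic involved, under that contiguity assumption (the helper name is mine, not from the diff):

    #include <cstdint>

    // Matrix i of a contiguous [batch, rows, cols] block starts at
    // base + i * rows * cols (row-major, no padding assumed).
    const float* MatrixBase(const float* base, std::int64_t i,
                            std::int64_t rows, std::int64_t cols) {
      return base + i * rows * cols;
    }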
@@ -196,86 +210,48 @@ class BatchMatMulMkl : public OpKernel {
       // pointer is to 2D matrix.
       MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
                         k_array, &a_array[0], lda_array, &b_array[0], ldb_array,
-                        &c_array[0], ldc_array, 1, group_size);
+                        &c_array[0], ldc_array, 1, group_size, ctx);
     }
   }
 
  private:
   bool adj_x_;
   bool adj_y_;
+  BatchMatMulV2Op<CPUDevice, Scalar> eigen_batch_mm_v2_;
 
-  template <typename T,
-            typename std::enable_if<(std::is_same<T, float>::value ||
-                                     std::is_same<T, double>::value),
-                                    int>::type = 0>
-  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
-                         const bool TransB, const std::vector<MKL_INT>& M_Array,
-                         const std::vector<MKL_INT>& N_Array,
-                         const std::vector<MKL_INT>& K_Array, const T** A_Array,
-                         const std::vector<MKL_INT>& lda_Array,
-                         const T** B_Array,
-                         const std::vector<MKL_INT>& ldb_Array, T** C_Array,
-                         const std::vector<MKL_INT>& ldc_Array,
-                         const MKL_INT group_count,
-                         const std::vector<MKL_INT>& group_size) {
+  void MklCblasGemmBatch(
+      const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
+      const std::vector<MKL_INT>& M_Array, const std::vector<MKL_INT>& N_Array,
+      const std::vector<MKL_INT>& K_Array, const float** A_Array,
+      const std::vector<MKL_INT>& lda_Array, const float** B_Array,
+      const std::vector<MKL_INT>& ldb_Array, float** C_Array,
+      const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
+      const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
+#ifndef ENABLE_MKLDNN_THREADPOOL
     std::vector<CBLAS_TRANSPOSE> TransA_Array(
         group_size[0], TransA ? CblasTrans : CblasNoTrans);
     std::vector<CBLAS_TRANSPOSE> TransB_Array(
         group_size[0], TransB ? CblasTrans : CblasNoTrans);
-    if (std::is_same<T, float>::value) {
-      std::vector<float> alpha_Array(group_size[0], 1.0);
-      std::vector<float> beta_Array(group_size[0], 0.0);
-      cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
-                        &N_Array[0], &K_Array[0], &alpha_Array[0],
-                        reinterpret_cast<const float**>(A_Array), &lda_Array[0],
-                        reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
-                        &beta_Array[0], reinterpret_cast<float**>(C_Array),
-                        &ldc_Array[0], group_count, &group_size[0]);
-    } else {
-      std::vector<double> alpha_Array(group_size[0], 1.0);
-      std::vector<double> beta_Array(group_size[0], 0.0);
-      cblas_dgemm_batch(
-          Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0], &N_Array[0],
-          &K_Array[0], &alpha_Array[0],
-          reinterpret_cast<const double**>(A_Array), &lda_Array[0],
-          reinterpret_cast<const double**>(B_Array), &ldb_Array[0],
-          &beta_Array[0], reinterpret_cast<double**>(C_Array), &ldc_Array[0],
-          group_count, &group_size[0]);
-    }
+    std::vector<float> alpha_Array(group_size[0], 1.0);
+    std::vector<float> beta_Array(group_size[0], 0.0);
+    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0],
+                      &N_Array[0], &K_Array[0], &alpha_Array[0],
+                      reinterpret_cast<const float**>(A_Array), &lda_Array[0],
+                      reinterpret_cast<const float**>(B_Array), &ldb_Array[0],
+                      &beta_Array[0], reinterpret_cast<float**>(C_Array),
+                      &ldc_Array[0], group_count, &group_size[0]);
+#else
+    DCHECK(Layout == CblasRowMajor);
+    std::vector<bool> TransA_Array(group_size[0], TransA);
+    std::vector<bool> TransB_Array(group_size[0], TransB);
+    std::vector<float> alpha_Array(group_size[0], 1.0);
+    std::vector<float> beta_Array(group_size[0], 0.0);
+    dnnl_gemm_batch<float>(TransA_Array, TransB_Array, M_Array, N_Array,
+                           K_Array, alpha_Array, *A_Array, *B_Array, beta_Array,
+                           *C_Array, group_count, group_size, ctx);
+#endif  // !ENABLE_MKLDNN_THREADPOOL
   }
 
-  template <typename T,
-            typename std::enable_if<(std::is_same<T, complex64>::value ||
-                                     std::is_same<T, complex128>::value),
-                                    int>::type = 0>
-  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
-                         const bool TransB, const std::vector<MKL_INT>& M_Array,
-                         const std::vector<MKL_INT>& N_Array,
-                         const std::vector<MKL_INT>& K_Array, const T** A_Array,
-                         const std::vector<MKL_INT>& lda_Array,
-                         const T** B_Array,
-                         const std::vector<MKL_INT>& ldb_Array, T** C_Array,
-                         const std::vector<MKL_INT>& ldc_Array,
-                         const MKL_INT group_count,
-                         const std::vector<MKL_INT>& group_size) {
-    std::vector<CBLAS_TRANSPOSE> TransA_array(
-        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
-    std::vector<CBLAS_TRANSPOSE> TransB_array(
-        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
-    std::vector<T> alpha_Array(group_size[0], {1.0f, 0.0f});
-    std::vector<T> beta_Array(group_size[0], {0.0f, 0.0f});
-    auto gemm_fn = (std::is_same<T, complex64>::value) ? cblas_cgemm_batch
-                                                       : cblas_zgemm_batch;
-    gemm_fn(Layout, &TransA_array[0], &TransB_array[0], &M_Array[0],
-            &N_Array[0], &K_Array[0], static_cast<const void*>(&alpha_Array[0]),
-            reinterpret_cast<const void**>(A_Array), &lda_Array[0],
-            reinterpret_cast<const void**>(B_Array), &ldb_Array[0],
-            static_cast<const void*>(&beta_Array[0]),
-            reinterpret_cast<void**>(C_Array), &ldc_Array[0], group_count,
-            &group_size[0]);
-  }
-
   // BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards.
 #if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
   void MklCblasGemmBatch(
       const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB,
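Note the shape of the rewritten overload above: the templated float/double version collapses to a single float implementation that dispatches at compile time, calling MKL's cblas_sgemm_batch when ENABLE_MKLDNN_THREADPOOL is off and dnnl_gemm_batch<float> (which can run on TensorFlow's thread pool via the forwarded ctx) when it is on. The complex64/complex128 overload is deleted outright, matching the narrowed type registration below.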
@@ -284,7 +260,7 @@ class BatchMatMulMkl : public OpKernel {
       const std::vector<MKL_INT>& lda_Array, const bfloat16** B_Array,
       const std::vector<MKL_INT>& ldb_Array, bfloat16** C_Array,
       const std::vector<MKL_INT>& ldc_Array, const MKL_INT group_count,
-      const std::vector<MKL_INT>& group_size) {
+      const std::vector<MKL_INT>& group_size, OpKernelContext* ctx) {
     DCHECK(Layout == CblasRowMajor);
     std::vector<bool> TransA_Array(group_size[0], TransA);
     std::vector<bool> TransB_Array(group_size[0], TransB);
@@ -292,9 +268,9 @@ class BatchMatMulMkl : public OpKernel {
     std::vector<float> beta_Array(group_size[0], 0.0);
     // TODO(nhasabni): Remove *A when we pass a, b, and c correctly.
     // MKLDNN API does not require lda, ldb, and ldc.
-    dnnl_gemm_batch<bfloat16>(TransA_Array, TransB_Array, M_Array, N_Array,
-                              K_Array, alpha_Array, *A_Array, *B_Array,
-                              beta_Array, *C_Array, group_count, group_size);
+    dnnl_gemm_batch<bfloat16>(
+        TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array,
+        *A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx);
   }
 #endif  // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
 };
@@ -315,13 +291,7 @@ class BatchMatMulMkl : public OpKernel {
 
 #ifdef ENABLE_MKL
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
-TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL);
-
 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
-TF_CALL_COMPLEX_TYPES(REGISTER_BATCH_MATMUL_MKL_V2);
 
 #if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
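The double and complex registrations dropped here keep the kernel registrations consistent with the narrowed T attribute in the REGISTER_OP hunks below; those dtypes now fall through to the stock Eigen BatchMatMul kernels instead of the _Mkl variants.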
@@ -707,8 +707,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
                      const std::vector<int>& n, const std::vector<int>& k,
                      const std::vector<float>& alpha, const T* a, const T* b,
                      const std::vector<float>& beta, T* c,
-                     const int group_count,
-                     const std::vector<int>& group_size) {
+                     const int group_count, const std::vector<int>& group_size,
+                     OpKernelContext* ctx = nullptr) {
   // Current BatchMatMul support in Tensorflow is narrower than the one offered
   // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
   // group of size equal to batch_size, and all MatMul parameters (m, n, k,
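The comment above refers to MKL's grouped-batch GEMM model: group_count groups, each with its own uniform m/n/k/transpose/alpha, and group_size[g] matrices in group g. TensorFlow only ever requests one group covering the whole batch. A self-contained sketch of that usage against the plain MKL CBLAS API (float path; assumes row-major layout and no transposes; the function name is mine):

    #include "mkl_cblas.h"

    // One group of `batch` same-shaped GEMMs: C[i] = A[i] * B[i].
    // Per-group parameter arrays have length group_count (= 1 here);
    // the pointer arrays a/b/c have one entry per matrix.
    void SingleGroupSgemmBatch(MKL_INT batch, MKL_INT m, MKL_INT n, MKL_INT k,
                               const float** a, const float** b, float** c) {
      CBLAS_TRANSPOSE trans = CblasNoTrans;
      float alpha = 1.0f, beta = 0.0f;
      MKL_INT lda = k, ldb = n, ldc = n;  // row-major leading dimensions
      MKL_INT group_size = batch;         // the single group spans the batch
      cblas_sgemm_batch(CblasRowMajor, &trans, &trans, &m, &n, &k, &alpha,
                        a, &lda, b, &ldb, &beta, c, &ldc,
                        /*group_count=*/1, &group_size);
    }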
@@ -142,9 +142,7 @@ REGISTER_OP("_MklBatchMatMul")
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr(
-        "T: {bfloat16, half, float, double, int32, int64, complex64, "
-        "complex128}")
+    .Attr("T: {bfloat16, float}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn(shape_inference::BatchMatMulShape);
@@ -153,9 +151,7 @@ REGISTER_OP("_MklBatchMatMulV2")
     .Input("x: T")
     .Input("y: T")
     .Output("output: T")
-    .Attr(
-        "T: {bfloat16, half, float, double, int32, int64, complex64, "
-        "complex128}")
+    .Attr("T: {bfloat16, float}")
     .Attr("adj_x: bool = false")
     .Attr("adj_y: bool = false")
     .SetShapeFn(shape_inference::BatchMatMulV2Shape);