matmul, qmatmul and fused-ops support for the threadpool API.
parent 735bb0fc23 · commit bcd0f459a1

tensorflow/core/kernels
@@ -62,11 +62,11 @@ class MklMatMulOp : public OpKernel {
     dim_pair[0].first = transpose_a_ ? 0 : 1;
     dim_pair[0].second = transpose_b_ ? 1 : 0;
 
-    OP_REQUIRES(
-        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
-            ", In[1]: ", b.shape().DebugString()));
+    OP_REQUIRES(ctx,
+                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        a.shape().DebugString(), ", In[1]: ",
+                                        b.shape().DebugString()));
     int a_dim_remaining = 1 - dim_pair[0].first;
     int b_dim_remaining = 1 - dim_pair[0].second;
     TensorShape out_shape(
@@ -158,9 +158,17 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     char char_transa = transa ? 'T' : 'N';
     char char_transb = transb ? 'T' : 'N';
-    VLOG(2) << "MKL DNN SGEMM CALLED";
+    VLOG(2) << "MKL DNN SGEMM called";
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    auto eigen_tp =
+        MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
+
+    dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
+                  beta, c, ldc, eigen_tp);
+#else
     dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
                c, ldc);
+#endif  // ENABLE_MKLDNN_THREADPOOL
 #else
     // TODO(intel-tf): Remove this after TF2.3 fork.
     cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
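Note: this hunk is the core of the FP32 threadpool path. With ENABLE_MKLDNN_THREADPOOL defined, the kernel wraps the OpKernelContext's intra-op Eigen threadpool and hands it to the threadpool-aware SGEMM entry point, so oneDNN schedules the GEMM on threads TensorFlow already owns instead of a separate OpenMP pool. A minimal sketch of the same dispatch follows; MklDnnThreadPoolWrapper, dnnl_sgemm_tp and dnnl_sgemm come from the diff, while the wrapper function itself is hypothetical scaffolding:

// Hypothetical wrapper (RunMklSgemm) around the dispatch in the hunk above.
void RunMklSgemm(OpKernelContext* ctx, bool transa, bool transb, int m, int n,
                 int k, float alpha, const float* a, int lda, const float* b,
                 int ldb, float beta, float* c, int ldc) {
  const char char_transa = transa ? 'T' : 'N';
  const char char_transb = transb ? 'T' : 'N';
#ifdef ENABLE_MKLDNN_THREADPOOL
  // Adapt TF's intra-op Eigen pool to oneDNN's threadpool interface so the
  // GEMM runs on the threads TF already owns (no oversubscription).
  auto eigen_tp =
      MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
  dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
                beta, c, ldc, eigen_tp);
#else
  // Non-threadpool builds keep oneDNN's own threading runtime.
  dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
             c, ldc);
#endif  // ENABLE_MKLDNN_THREADPOOL
}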
@@ -182,7 +190,7 @@ class MklMatMulOp : public OpKernel {
 #ifdef ENABLE_MKLDNN_V1
     const char ftrans[] = {'N', 'T', 'C'};
     dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
-                        alpha, a, lda, b, ldb, beta, c, ldc);
+                        alpha, a, lda, b, ldb, beta, c, ldc, ctx);
 #else
     Tensor c_float;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
@@ -86,11 +86,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     const int k = src_tf_shape.dim_size(dim_pair[0]);
     const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
-    OP_REQUIRES(
-        ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
-        errors::InvalidArgument(
-            "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
-            ", In[1]: ", weight_tf_shape.DebugString()));
+    OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
+                                        src_tf_shape.DebugString(), ", In[1]: ",
+                                        weight_tf_shape.DebugString()));
     OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
                 errors::InvalidArgument(
                     "Must provide as many biases as the channel size: ",
@@ -159,8 +158,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
 
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
         src_mkl.SetUsrMem(src_md, src_data);
-        src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src_mkl.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            ctx);
         src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
       }
 
@@ -191,19 +192,23 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
           weight_data = cached_weight_data;
         } else {
           weight_mkl.SetUsrMem(weight_md, weight_data);
-          weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+          weight_mkl.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                                     this->cpu_engine_),
+              ctx);
           weight_data =
               reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
         }
       }
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
       // Execute fused matmul op.
-      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
+      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data,
+                           cpu_stream);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
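Note: the pattern introduced here recurs throughout the commit. The cached primitive no longer executes on a stream it owns; instead the op builds a stream from its own OpKernelContext immediately before execution and passes it down. A hedged sketch of the call sequence (CreateStream and the Execute signature come from the diff; the helper itself is illustrative, and PrimT is a placeholder for any cached MKL primitive exposing GetEngine()/Execute()):

// Illustrative helper showing the per-call stream pattern used above.
template <typename PrimT, typename T>
void ExecuteFusedMatMul(OpKernelContext* ctx, PrimT* matmul_prim,
                        const T* src, const T* weight, const T* bias, T* dst) {
  std::shared_ptr<stream> cpu_stream;
  // CreateStream is assumed to return a threadpool-backed stream when ctx
  // carries an Eigen pool (threadpool builds), a plain CPU stream otherwise.
  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
  matmul_prim->Execute(src, weight, bias, dst, cpu_stream);
}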
@@ -75,8 +75,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
  public:
   explicit MklDnnMatMulFwdPrimitive(
       const MklDnnMatMulFwdParams& matmulFwdParams)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     if (context_.matmul_fwd == nullptr) {
       Setup(matmulFwdParams);
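Note: the constructor now delegates to MklPrimitive, which owns the engine, and the per-primitive stream member is dropped. A simplified stand-in for what the base class provides after this change (a sketch, not TF's exact definition):

// Simplified stand-in for the MklPrimitive base class assumed by the
// delegating constructor above: it owns only the engine, never a stream.
class MklPrimitive {
 public:
  explicit MklPrimitive(const engine& cpu_engine) : cpu_engine_(cpu_engine) {}
  virtual ~MklPrimitive() {}

  // Streams are created per call by the op (from its OpKernelContext) and
  // passed into Execute(), so one cached primitive can serve kernels that
  // run on different threadpools.
  const engine& GetEngine() const { return cpu_engine_; }

 protected:
  engine cpu_engine_;
};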
@@ -91,7 +90,18 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   // - bias_data: input data buffer of bias
   // - dst_data: output data buffer of dst
   void Execute(const Tinput* src_data, const Tweight* weight_data,
-               const Tbias* bias_data, Toutput* dst_data) {
+               const Tbias* bias_data, Toutput* dst_data,
+               std::shared_ptr<stream> fwd_stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
+    context_.weight_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
+    context_.bias_mem->set_data_handle(
+        static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
+                                      *fwd_stream);
+#else
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.weight_mem->set_data_handle(
@@ -99,12 +109,12 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     context_.bias_mem->set_data_handle(
         static_cast<void*>(const_cast<Tbias*>(bias_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+#endif  // ENABLE_MKLDNN_THREADPOOL
 
 #ifdef ENABLE_MKLDNN_V1
-    execute_primitives(context_.fwd_primitives, context_.fwd_stream,
-                       context_.net_args);
+    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
 #else
-    context_.fwd_stream->submit(context_.fwd_primitives);
+    fwd_stream->submit(context_.fwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
 
     // After execution, set data handle back
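Note: in threadpool builds the data handles are bound with the stream as a second argument. Binding a new buffer can make oneDNN zero the memory's padded area, and with the threadpool runtime that work needs a stream to run on, hence the `set_data_handle(ptr, stream)` overload for buffers that may be padded. A minimal sketch of the two binding modes, assuming `mem` is an mkldnn::memory and `buf` a caller-owned buffer:

// Minimal sketch of the two binding modes used in Execute() above.
void Bind(mkldnn::memory& mem, void* buf,
          const std::shared_ptr<mkldnn::stream>& fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
  // The stream-taking overload lets any implicit zero-padding of the newly
  // bound buffer run on the caller's (threadpool-backed) stream.
  mem.set_data_handle(buf, *fwd_stream);
#else
  mem.set_data_handle(buf);
#endif  // ENABLE_MKLDNN_THREADPOOL
}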
@@ -153,7 +163,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
     // Inner-product primitive.
     std::shared_ptr<mkldnn::primitive> matmul_fwd;
-    std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 
 #ifdef ENABLE_MKLDNN_V1
@@ -176,8 +185,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
           weight_md(nullptr),
           bias_md(nullptr),
           dst_md(nullptr),
-          matmul_fwd(nullptr),
-          fwd_stream(nullptr) {
+          matmul_fwd(nullptr) {
     }
   };
 
@@ -292,7 +300,6 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
   }
 
   struct MklDnnMatMulFwdContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T, typename Tinput, typename Tweight, typename Tbias,
@@ -439,8 +446,10 @@ class MklDnnMatMulOpBase : public OpKernel {
 
     // reorder and cache the weight
     weight.SetUsrMem(weight_md, &weight_tensor);
-    weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-        matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
+    weight.CheckReorderToOpMem(
+        MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                               cpu_engine_),
+        context);
     weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
 
     Tensor* weight_tensor_ptr = nullptr;
@@ -544,21 +553,28 @@ template <typename T>
 class MklMatMulPrimitive : public MklPrimitive {
  public:
   explicit MklMatMulPrimitive(const MklMatMulParams& params)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create matmul primitive
     Setup(params);
   }
 
   ~MklMatMulPrimitive() {}
 
-  void Execute(const T* a_data, const T* b_data, T* c_data) {
+  void Execute(const T* a_data, const T* b_data, T* c_data,
+               std::shared_ptr<stream> stream) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
+                                    *stream);
+    context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
+                                    *stream);
+    context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)),
+                                    *stream);
+#else
     context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
     context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
     context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
-    execute_primitives(context_.matmul_primitives, context_.stream,
-                       context_.net_args);
+#endif  // ENABLE_MKLDNN_THREADPOOL
+    execute_primitives(context_.matmul_primitives, stream, context_.net_args);
 
     // After execution, set data handle back
     context_.a_mem->set_data_handle(DummyData);
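Note: the same Execute-with-stream change applies to the plain MklMatMulPrimitive. The primitive comes out of a shared cache while the stream is per invocation, so two kernels with different threadpools can reuse one primitive. A sketch of a caller, mirroring the dnnl_gemm/dnnl_gemm_batch changes further down in this diff (the helper function itself is illustrative):

// Sketch of a caller: the primitive is fetched from the shared cache, the
// stream is built per call from the current op context.
template <typename T>
void RunCachedMatMul(OpKernelContext* ctx, const MklMatMulParams& params,
                     const T* a, const T* b, T* c) {
  MklMatMulPrimitive<T>* matmul_prim =
      MklMatMulPrimitiveFactory<T>::Get(params, 0);  // cached, shared
  std::shared_ptr<stream> cpu_stream;
  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));  // per call
  matmul_prim->Execute(a, b, c, cpu_stream);
}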
@@ -584,7 +600,6 @@ class MklMatMulPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::memory::desc> c_md;
 
     // MatMul primitive.
-    std::shared_ptr<mkldnn::stream> stream;
     std::vector<mkldnn::primitive> matmul_primitives;
     std::vector<std::unordered_map<int, memory>> net_args;
 
@@ -596,8 +611,7 @@ class MklMatMulPrimitive : public MklPrimitive {
           prim_desc(nullptr),
           a_md(nullptr),
           b_md(nullptr),
-          c_md(nullptr),
-          stream(nullptr) {}
+          c_md(nullptr) {}
   };
 
   void Setup(const MklMatMulParams& params) {
@@ -639,7 +653,6 @@ class MklMatMulPrimitive : public MklPrimitive {
   }
 
   struct MklMatMulContext context_;
-  engine cpu_engine_;
 };
 
 template <typename T>
@@ -707,8 +720,8 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
                      const std::vector<int>& n, const std::vector<int>& k,
                      const std::vector<float>& alpha, const T* a, const T* b,
                      const std::vector<float>& beta, T* c,
-                     const int group_count,
-                     const std::vector<int>& group_size) {
+                     const int group_count, const std::vector<int>& group_size,
+                     OpKernelContext* ctx = nullptr) {
   // Current BatchMatMul support in Tensorflow is narrower than the one offered
   // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
   // group of size equal to batch_size, and all MatMul parameters (m, n, k,
@@ -757,13 +770,15 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 template <typename T>
 void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
                float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
-               float beta, T* c, int64_t ldc) {
+               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
   using dims = mkldnn::memory::dims;
 
   // Prepare strides based on the transa and transb flags: transposed
@@ -786,7 +801,9 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
       MklMatMulPrimitiveFactory<T>::Get(params, 0);
 
   // Execute matmul primitive.
-  matmul_prim->Execute(a, b, c);
+  std::shared_ptr<stream> cpu_stream;
+  cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine()));
+  matmul_prim->Execute(a, b, c, cpu_stream);
 }
 
 }  // anonymous namespace
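Note: both free helpers gain a defaulted `OpKernelContext* ctx = nullptr` parameter, so existing call sites keep compiling while threadpool-aware callers thread their context through to CreateStream. Hypothetical call sites (row-major, no transposes; CreateStream is assumed to fall back to a default CPU stream when ctx is nullptr):

// Hypothetical call sites for the new dnnl_gemm signature.
void GemmExample(OpKernelContext* ctx, const float* a, const float* b,
                 float* c) {
  // 2x3 * 3x2 -> 2x2: m=2, n=2, k=3, lda=k, ldb=n, ldc=n.
  dnnl_gemm<float>('N', 'N', 2, 2, 3, 1.0f, a, 3, b, 2, 0.0f, c, 2, ctx);
  // Pre-existing style still compiles because ctx defaults to nullptr.
  dnnl_gemm<float>('N', 'N', 2, 2, 3, 1.0f, a, 3, b, 2, 0.0f, c, 2);
}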
@@ -245,8 +245,10 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       Tinput* src_data = nullptr;
       if (IS_SRC_REORDER_NEEDED(src_md, matmul_fwd_pd, matmul_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_));
+        src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(matmul_fwd_pd.get()->PRIMITIVE_DESC_SRC,
+                                   this->cpu_engine_),
+            context);
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -279,8 +281,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
 
         if (!is_weight_cached) {
           weight.SetUsrMem(weight_md, &weight_tensor);
-          weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_));
+          weight.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(
+                  matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS,
+                  this->cpu_engine_),
+              context);
           weight_data =
               static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
         }
@@ -290,10 +295,13 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
             const_cast<Tweight*>(weight_tensor.flat<Tweight>().data()));
       }
 
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine()));
       // Execute inner-product
-      Tbias* bias_data = this->GetBiasHandle(context, matmul_fwd_pd,
-                                             bias_tensor, weight_tensor);
-      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data);
+      Tbias* bias_data = this->GetBiasHandle(
+          context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
+      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
+                          cpu_stream);
     } catch (mkldnn::error& e) {
       string error_msg = tensorflow::strings::StrCat(
           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
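Note: in the quantized op the stream is created before fetching the bias handle because GetBiasHandle may itself run a reorder (rescaling the bias via the output scales set on bias_attr, per the hunks below), and that reorder now executes on the same caller-provided stream as the matmul. The resulting flow, condensed from the hunk above:

// Condensed from the hunk above: one stream serves both the bias-scaling
// reorder inside GetBiasHandle() and the inner-product execution.
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(context, matmul_fwd->GetEngine()));
Tbias* bias_data = this->GetBiasHandle(
    context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data, cpu_stream);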
@@ -393,7 +401,8 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       OpKernelContext* context,
       std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>&
           mkldnn_matmul_fwd_pd,
-      const Tensor& bias_tensor, const Tensor& weight_tensor) {
+      const Tensor& bias_tensor, const Tensor& weight_tensor,
+      std::shared_ptr<stream> reorder_stream) {
     // If the bias is qint32, it means the bias is already converted offline.
     // and it can be added to matmul output directly.
     if (std::is_same<Tbias, qint32>::value) {
@@ -449,7 +458,6 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
       std::vector<float> scales;
       scales.push_back(out_scale);
       mkldnn::primitive_attr bias_attr;
-      stream reorder_stream = CPU_STREAM(this->cpu_engine_);
       bias_attr.set_output_scales(0, scales);
 
       void* bias_buf = static_cast<void*>(
@@ -468,14 +476,14 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
           {MKLDNN_ARG_FROM, *input_bias_},
           { MKLDNN_ARG_TO,
             *scaled_bias_ }};
-      net.at(0).execute(reorder_stream, reorder_net_args);
+      net.at(0).execute(*reorder_stream, reorder_net_args);
 #else
       auto reorder_desc = mkldnn::reorder::primitive_desc(
           input_bias_->get_primitive_desc(),
           scaled_bias_->get_primitive_desc(), bias_attr);
       net.push_back(
           mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
-      reorder_stream.submit(net).wait();
+      reorder_stream->submit(net).wait();
 #endif  // ENABLE_MKLDNN_V1
 
       return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());