Merge pull request #31782 from Intel-tensorflow:tenglu/fuse_fp32_matmul

PiperOrigin-RevId: 270955757
TensorFlower Gardener 2019-09-24 12:22:23 -07:00
commit 8e3f8ae473
10 changed files with 803 additions and 349 deletions

View File

@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
csinfo_.fused_conv2d = "_FusedConv2D";
csinfo_.fused_matmul = "_FusedMatMul";
csinfo_.identity = "Identity";
csinfo_.leakyrelu = "LeakyRelu";
csinfo_.leakyrelu_grad = "LeakyReluGrad";
@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_depthwise_conv2d_grad_filter =
"_MklDepthwiseConv2dNativeBackpropFilter";
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
csinfo_.pad = "Pad";
@ -478,6 +480,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite,
kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
CopyAttrsAll, FusedMatMulRewrite});
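// Editor's note: CopyAttrsAll copies all attributes of _FusedMatMul onto the
// new _MklFusedMatMul node, and FusedMatMulRewrite (defined below) decides
// whether the rewrite applies.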
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
@ -905,6 +910,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
string fused_batch_norm_v3;
string fused_batch_norm_grad_v3;
string fused_conv2d;
string fused_matmul;
string identity;
string leakyrelu;
string leakyrelu_grad;
@ -924,6 +930,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
string mkl_depthwise_conv2d_grad_input;
string mkl_depthwise_conv2d_grad_filter;
string mkl_fused_conv2d;
string mkl_fused_matmul;
string mkl_pad_with_conv2d;
string mkl_pad_with_fused_conv2d;
string mul;
@ -1453,6 +1460,22 @@ rinfo_.push_back({csinfo_.tanh_grad,
return false;
}
// Rewrite rule for _FusedMatMul.
// @return - true if transpose_a is false and the node has exactly one post op;
//           false otherwise.
static bool FusedMatMulRewrite(const Node* n) {
bool trans_a;
std::vector<string> fused_ops;
// Do not rewrite when transpose_a is set, because the required reorder hurts
// performance.
TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
// Do not rewrite with more than one post op, since MKL-DNN does not support that.
TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
return (!trans_a) && (fused_ops.size() == 1);
}
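// Editor's illustration (not part of the pass): given the rule above,
//   transpose_a=false, fused_ops={"BiasAdd"}         -> rewritten to _MklFusedMatMul
//   transpose_a=true,  fused_ops={"BiasAdd"}         -> left as _FusedMatMul
//   transpose_a=false, fused_ops={"BiasAdd", "Relu"} -> left as _FusedMatMul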
// Check if we are performing pooling on depth or batch. If it is, then we
// do not rewrite MaxPool node to Mkl version.
// @return - true (if it is not a depth/batch wise pooling case);
@ -3553,6 +3576,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
n->type_string() != csinfo_.pad_with_fused_conv2d &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
n->type_string() != csinfo_.fused_conv2d &&
n->type_string() != csinfo_.fused_matmul &&
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
T)) {
return nullptr;

View File

@ -1874,6 +1874,53 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) {
"D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E");
}
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedMatMul'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'transpose_a' value { b: false } }"
" attr { key: 'transpose_b' value { b: false } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedMatMul);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);Z(Zeta)"
"|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;DMT/_0->D:3;"
"DMT/_1->D:4;DMT/_2->D:5");
}
// Negative test: _FusedMatMul with transpose_a = true is not rewritten.
TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Negative) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedMatMul'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'transpose_a' value { b: true } }"
" attr { key: 'transpose_b' value { b: false } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_FusedMatMul);Z(Zeta)"
"|A->D;B->D:1;C->D:2;C->Z:1;D->Z");
}
// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
// padding is VALID type
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,

View File

@ -263,13 +263,7 @@ bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {
bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
#ifndef INTEL_MKL
// Temporarily disable Matmul fusions if MKL is enabled.
// TODO(Intel) renable Matmul fusions when enabled by MKL DNN.
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
#else
return false;
#endif // !INTEL_MKL
}
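// Editor's sketch (hedged; the remapper pattern-matching code is not shown in
// this hunk): when this predicate returns true, the remapper may collapse
//   y = BiasAdd(MatMul(a, b), bias)
// into
//   y = _FusedMatMul(a, b, bias) with fused_ops = ["BiasAdd"].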
// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.

View File

@ -3825,7 +3825,11 @@ tf_kernel_library(
tf_mkl_kernel_library(
name = "mkl_matmul_op",
srcs = ["mkl_matmul_op.cc"],
srcs = [
"mkl_matmul_op.cc",
"mkl_matmul_op_fused.cc",
],
hdrs = ["mkl_matmul_ops_common.h"],
deps = MATH_DEPS + mkl_deps(),
)
@ -7546,6 +7550,7 @@ tf_mkl_kernel_library(
name = "mkl_qmatmul_op",
srcs = ["mkl_qmatmul_op.cc"],
hdrs = [
"mkl_matmul_ops_common.h",
"mkl_quantized_conv_ops.h",
"no_op.h",
],
@ -7860,6 +7865,7 @@ tf_cc_test_mkl(
":conv_ops",
":image",
":mkl_conv_op",
":mkl_matmul_op",
":mkl_tfconv_op",
":ops_testutil",
":ops_util",

View File

@ -160,6 +160,33 @@ class CommonTestUtilities : public OpsTestBase {
test::ExpectClose(conv_2d, fused_conv_2d, 1e-5);
}
static void VerifyFusedMatrixClose(int depth, int batch, int weight_count,
const std::vector<string>& fused_ops,
const FusedGraphRunner& run_default,
const FusedGraphRunner& run_fused) {
DataType dtype = DataTypeToEnum<T>::v();
Tensor input(dtype, {batch, depth});
input.flat<T>() = input.flat<T>().setRandom();
Tensor weight(dtype, {depth, weight_count});
weight.flat<T>() = weight.flat<T>().setRandom();
Tensor bias(dtype, {weight_count});
bias.flat<T>() = bias.flat<T>().setRandom();
Tensor output;
Tensor fused_output;
run_default(input, weight, bias, fused_ops, &output);
run_fused(input, weight, bias, fused_ops, &fused_output);
ASSERT_EQ(output.dtype(), fused_output.dtype());
ASSERT_EQ(output.shape(), fused_output.shape());
test::ExpectClose(output, fused_output, 1e-5);
}
};
// Testing MKL's fused convolution ops
@ -586,6 +613,97 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
Run<float>(DT_FLOAT, image, filter, expected, true);
}
// Testing fusion of MatMul and BiasAdd
template <typename T>
class MklFusedMatMulOpTest : public OpsTestBase {
protected:
void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
const int kOutputChannel) {
const FusedGraphRunner run_default =
[this](const Tensor& input, const Tensor& weight, const Tensor& bias,
const std::vector<string>& fused_ops, Tensor* output) {
auto root = tensorflow::Scope::NewRootScope();
auto input_op =
ops::Const(root.WithOpName("input"), Input::Initializer(input));
Output next_op = ops::MatMul(root.WithOpName("matmul"), input_op,
ops::Const(root.WithOpName("weight"),
Input::Initializer(weight)));
string last_op = "";
if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
fused_ops.end()) {
last_op = "with_bias";
next_op = ops::BiasAdd(
root.WithOpName(last_op), next_op,
ops::Const(root.WithOpName("bias"), Input::Initializer(bias)));
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
};
const FusedGraphRunner run_fused =
[this](const Tensor& input, const Tensor& weight, const Tensor& bias,
const std::vector<string>& fused_ops, Tensor* output) {
DataType dtype = DataTypeToEnum<T>::v();
const int num_args = fused_ops.size();
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(input.shape(), input.flat<T>());
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
AddInputFromArray<T>(bias.shape(), bias.flat<T>());
// Add MKL meta inputs for input, weight and bias.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_tensor = *GetOutput(0);
const Tensor& output_meta_tensor = *GetOutput(1);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
output);
};
CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
kOutputChannel, {"BiasAdd"},
run_default, run_fused);
}
};
TYPED_TEST_CASE_P(MklFusedMatMulOpTest);
TYPED_TEST_P(MklFusedMatMulOpTest, BasicTest) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 5;
this->VerifyFusedMatMul(batch, input_channel, output_channel);
}
REGISTER_TYPED_TEST_CASE_P(MklFusedMatMulOpTest, BasicTest);
using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedMatMulOpTest,
MklFusedMatMulDataTypes);
class BiasCacheTest : public OpsTestBase {
public:
template <typename T>
@ -715,6 +833,7 @@ TEST_F(BiasCacheTest, Conv2DBiasCacheTest) {
Run<float>(DT_QUINT8, image, filter, bias, min_input, max_input, min_filter,
max_filter, min_output, max_output, expected, true);
}
// Testing fusion of pad and fusedconv2d
template <typename T>
class MklPadWithFusedConv2DOpTest : public OpsTestBase {

View File

@ -0,0 +1,193 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
// This file uses MKL-DNN InnerProduct for acceleration of TF Matrix-Matrix
// Multiplication (MatMul) with bias (BiasAdd) operations.
#ifdef INTEL_MKL
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
// Fused operation: MatMul + BiasAdd.
template <typename Device, typename T>
class MklFusedMatMulOp : public MklDnnMatMulOpBase<T> {
public:
explicit MklFusedMatMulOp(OpKernelConstruction* ctx)
: MklDnnMatMulOpBase<T>(ctx) {
std::vector<string> fused_ops;
OP_REQUIRES_OK(ctx, ctx->GetAttr("fused_ops", &fused_ops));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
OP_REQUIRES(ctx, fused_ops.size() == 1,
errors::InvalidArgument(
"MklFusedMatMul must have only one argument: bias."));
OP_REQUIRES(
ctx, transpose_a_ == false,
errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
}
void Compute(OpKernelContext* ctx) override {
// FusedMatMul has 3 inputs: src, weights, bias
const Tensor& src_tensor = ctx->input(this->kInputIndexSrc);
const Tensor& weight_tensor = ctx->input(this->kInputIndexWeight);
const Tensor& bias_tensor = MklGetInput(ctx, this->kInputIndexBias);
MklDnnShape src_mkl_shape;
MklDnnShape weight_mkl_shape;
GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape);
GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape);
// Get shapes of input tensors
auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
: src_tensor.shape();
auto weight_tf_shape = weight_mkl_shape.IsMklTensor()
? weight_mkl_shape.GetTfShape()
: weight_tensor.shape();
// Check the constraint of input matrix and bias
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(src_tf_shape),
errors::InvalidArgument("In[0] is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(weight_tf_shape),
errors::InvalidArgument("In[1] is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_tensor.shape()),
errors::InvalidArgument("Biases must be 1D"));
// Expression: [batch, k] * [k, channel] + [channel] = [batch, channel]
//
// Get the dimension sizes of each matrix. dim_pair[] holds the position of k
// in each input; k must be the same in both inputs.
const int dim_pair[] = {1, transpose_b_ ? 1 : 0};
const int batch = src_tf_shape.dim_size(1 - dim_pair[0]);
const int k = src_tf_shape.dim_size(dim_pair[0]);
const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
OP_REQUIRES(
ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
", In[1]: ", weight_tf_shape.DebugString()));
OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
errors::InvalidArgument(
"Must provide as many biases as the channel size: ",
bias_tensor.shape().DebugString(), " vs. ", channel));
// For inputs s[batch, k], w[k, channel] and b[channel], the primitive
// dims should be described like this:
// s[batch, k] * w^T[channel, k] + b[channel] = dst[batch, channel]
// [n, ic] * [oc, ic] + [oc] = [n, oc]
memory::dims src_dims = memory::dims({batch, k});
// Reverse the weights dims from [k, channel] to [channel, k].
memory::dims weight_dims = memory::dims({channel, k});
memory::dims bias_dims = memory::dims({channel});
memory::dims dst_dims = memory::dims({batch, channel});
memory::format weight_format =
transpose_b_ ? memory::format::oi : memory::format::io;
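// Worked example (editor's illustration; values are hypothetical): with
// batch = 3, k = 4, channel = 5 and transpose_b_ == false:
//   src_dims = {3, 4}, weight_dims = {5, 4}, bias_dims = {5}, dst_dims = {3, 5}
//   weight_format = memory::format::io, because the TF weight buffer is laid
//   out as [k, channel] = [4, 5].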
MklDnnMatMulFwdParams matmul_params(src_dims, weight_dims, bias_dims,
dst_dims, weight_format);
MklDnnMatMulFwdPrimitive<T, T, T, T, T>* matmul_prim =
MklDnnMatMulFwdPrimitiveFactory<T, T, T, T, T>::Get(matmul_params, 0);
// Allocate output tensor.
Tensor* dst_tensor = nullptr;
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
matmul_prim->GetPrimitiveDesc();
if (src_mkl_shape.IsMklTensor() && weight_mkl_shape.IsMklTensor()) {
this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, memory::format::nc,
&dst_tensor);
} else {
TensorShape dst_tensor_shape({batch, channel});
MklDnnShape dst_mkl_shape;
dst_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
dst_mkl_shape);
}
// if there's nothing to compute, just return.
if (batch == 0 || channel == 0) {
return;
}
try {
// Prepare the input and output for primitive.
T* src_data = const_cast<T*>(src_tensor.flat<T>().data());
T* weight_data = const_cast<T*>(weight_tensor.flat<T>().data());
T* bias_data = const_cast<T*>(bias_tensor.flat<T>().data());
T* dst_data = const_cast<T*>(dst_tensor->flat<T>().data());
// If an input is in MKL format, reorder it to the format the primitive expects.
MklDnnData<T> src_mkl(&(this->cpu_engine_));
MklDnnData<T> weight_mkl(&(this->cpu_engine_));
if (src_mkl_shape.IsMklTensor()) {
memory::desc input_md = src_mkl_shape.GetMklLayout();
if (input_md.data.format != memory::format::nc) {
src_mkl.SetUsrMem(input_md, src_data);
src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_primitive_desc());
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
}
}
if (weight_mkl_shape.IsMklTensor()) {
memory::desc input_md = weight_mkl_shape.GetMklLayout();
if (input_md.data.format != weight_format) {
weight_mkl.SetUsrMem(input_md, weight_data);
weight_mkl.CheckReorderToOpMem(
matmul_pd.get()->weights_primitive_desc());
weight_data =
reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
}
}
matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
ctx, errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
bool transpose_a_;
bool transpose_b_;
};
// Register mkl kernels for supported operations and types.
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER( \
Name("_MklFusedMatMul") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedMatMulOp<CPUDevice, type>);
TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
} // namespace tensorflow
#endif // INTEL_MKL

View File

@ -0,0 +1,372 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_
#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>
#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"
using mkldnn::inner_product_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
memory::dims src_dims;
memory::dims weight_dims;
memory::dims bias_dims;
memory::dims dst_dims;
memory::format weight_fmt;
string dtypes = string("");
struct PostOpParam {
string name;
std::vector<float> param;
};
std::vector<PostOpParam> post_op_params;
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
memory::dims bias_dims, memory::dims dst_dims,
memory::format weight_fmt = memory::format::any)
: src_dims(src_dims),
weight_dims(weight_dims),
bias_dims(bias_dims),
dst_dims(dst_dims),
weight_fmt(weight_fmt) {}
};
// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
public:
explicit MklDnnMatMulFwdPrimitive(
const MklDnnMatMulFwdParams& matmulFwdParams)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// Create matmul primitive
if (context_.matmul_fwd == nullptr) {
Setup(matmulFwdParams);
}
}
~MklDnnMatMulFwdPrimitive() {}
// Inner-product forward execute with bias:
// - src_data: input data buffer of src
// - weight_data: input data buffer of weight
// - bias_data: input data buffer of bias
// - dst_data: output data buffer of dst
void Execute(const Tinput* src_data, const Tweight* weight_data,
const Tbias* bias_data, Toutput* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)));
context_.weight_mem->set_data_handle(
static_cast<void*>(const_cast<Tweight*>(weight_data)));
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<Tbias*>(bias_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);
// After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
context_.weight_mem->set_data_handle(DummyData);
context_.bias_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetweightMemoryFormat() const { return context_.weight_fmt; }
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
GetPrimitiveDesc() const {
return context_.fwd_pd;
}
private:
// Primitive reuse context for inner-product Fwd op
struct MklDnnMatMulFwdContext {
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format weight_fmt;
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> weight_mem;
std::shared_ptr<mkldnn::memory> bias_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// Descriptor and primitive-descriptor for forward inner-product
std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;
// Memory descriptors
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> weight_md;
std::shared_ptr<mkldnn::memory::desc> bias_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// Inner-product primitive
std::shared_ptr<mkldnn::primitive> matmul_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
MklDnnMatMulFwdContext()
: src_fmt(memory::format::any),
weight_fmt(memory::format::any),
src_mem(nullptr),
weight_mem(nullptr),
bias_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
fwd_pd(nullptr),
src_md(nullptr),
weight_md(nullptr),
bias_md(nullptr),
dst_md(nullptr),
matmul_fwd(nullptr),
fwd_stream(nullptr) {}
};
void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
// Create memory descriptors for inner-product data with no specified format
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
MklDnnType<Tinput>(),
memory::format::any));
context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
MklDnnType<Tweight>(),
matmul_fwd_params.weight_fmt));
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
MklDnnType<Toutput>(),
memory::format::any));
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(),
memory::format::any));
// Create the inner-product forward descriptor and primitive descriptor.
context_.fwd_desc.reset(new inner_product_forward::desc(
prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
*context_.bias_md, *context_.dst_md));
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
// Check whether any fusions were requested as post ops.
auto const& post_op_params = matmul_fwd_params.post_op_params;
mkldnn::primitive_attr post_ops_attr;
mkldnn::post_ops post_ops;
if (!post_op_params.empty()) {
for (auto const& post_op_param : post_op_params) {
if (post_op_param.name == "relu") {
DCHECK_EQ(post_op_param.param.size(), 3);
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
op_beta);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
std::vector<float> scales;
scales.push_back(post_op_param.param[0]);
post_ops_attr.set_output_scales(0, scales);
} else {
DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "output_scale"));
}
}
post_ops_attr.set_post_ops(post_ops);
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, post_ops_attr, cpu_engine_));
} else {
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
}
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
context_.weight_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// Create memory primitive based on dummy data
context_.src_mem.reset(
new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
context_.weight_mem.reset(
new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(),
memory::format::x},
cpu_engine_},
DummyData));
// Create inner-product primitive
context_.matmul_fwd.reset(new inner_product_forward(
*context_.fwd_pd, *context_.src_mem, *context_.weight_mem,
*context_.bias_mem, *context_.dst_mem));
context_.fwd_primitives.push_back(*context_.matmul_fwd);
return;
}
struct MklDnnMatMulFwdContext context_;
engine cpu_engine_;
};
template <typename T, typename Tinput, typename Tweight, typename Tbias,
typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
nullptr;
if (do_not_cache) {
// Always create new primitive
matmul_fwd =
new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
mkldnn_matmul_fwd_dims);
} else {
// Try to find a suitable one in pool
matmul_fwd = dynamic_cast<
MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
Toutput>::GetInstance()
.GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
if (matmul_fwd == nullptr) {
matmul_fwd =
new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
mkldnn_matmul_fwd_dims);
MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
Toutput>::GetInstance()
.SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
}
}
return matmul_fwd;
}
private:
MklDnnMatMulFwdPrimitiveFactory() {}
~MklDnnMatMulFwdPrimitiveFactory() {}
static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
static MklDnnMatMulFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
string prefix = "matmul_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
// Generate keys for post-ops
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
if (post_op_param.name == "relu") {
DCHECK_EQ(post_op_param.param.size(), 3);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
key_creator.AddAsKey(post_op_param.param[1]);
key_creator.AddAsKey(post_op_param.param[2]);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
} else {
return string("not_a_key");
}
}
return key_creator.GetKey();
}
MklPrimitive* GetMklDnnMatMulFwd(
const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
string key = CreateKey(mkldnn_matmul_fwd_dims);
return this->GetOp(key);
}
void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
MklPrimitive* op) {
string key = CreateKey(mkldnn_matmul_fwd_dims);
this->SetOp(key, op);
}
};
template <class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
public:
explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override = 0;
// Allocate output tensor.
virtual void AllocateOutputTensor(
OpKernelContext* context,
const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
DCHECK(output_tensor);
auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc();
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<Toutput>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
TensorShape output_tf_shape;
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
// Allocate Output Tensor
AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
output_tf_shape, output_mkl_shape);
}
engine cpu_engine_ = engine(engine::cpu, 0);
protected:
const int kInputIndexSrc = 0;
const int kInputIndexWeight = 1;
const int kInputIndexBias = 2;
const int kOutputIndexDst = 0;
};
} // namespace tensorflow
#endif // INTEL_MKL
#endif // TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_
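A hedged usage sketch (editor's addition, not part of this change), showing how the post_op_params field above can carry a ReLU post op before a primitive is requested from the factory. It assumes the memory::dims and data pointers are prepared as in mkl_matmul_op_fused.cc, and uses the {scale, alpha, beta} ordering that Setup() reads back:

MklDnnMatMulFwdParams params(src_dims, weight_dims, bias_dims, dst_dims,
                             memory::format::io);
// scale = 1.0, alpha = 0.0, beta = 0.0 for a plain ReLU post op.
params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});
MklDnnMatMulFwdPrimitive<float, float, float, float, float>* matmul_prim =
    MklDnnMatMulFwdPrimitiveFactory<float, float, float, float, float>::Get(
        params, /*do_not_cache=*/false);
matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);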

View File

@ -90,19 +90,12 @@ limitations under the License.
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
#ifdef INTEL_MKL
#include "mkldnn.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
using mkldnn::inner_product_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace {
enum {
@ -113,299 +106,9 @@ enum {
namespace tensorflow {
// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
memory::dims src_dims;
memory::dims weight_dims;
memory::dims bias_dims;
memory::dims dst_dims;
string dtypes = string("");
struct PostOpParam {
string name;
std::vector<float> param;
};
std::vector<PostOpParam> post_op_params;
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
memory::dims bias_dims, memory::dims dst_dims)
: src_dims(src_dims),
weight_dims(weight_dims),
bias_dims(bias_dims),
dst_dims(dst_dims) {}
};
// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
public:
explicit MklDnnMatMulFwdPrimitive(
const MklDnnMatMulFwdParams& matmulFwdParams)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// Create matmul primitive
if (context_.matmul_fwd == nullptr) {
Setup(matmulFwdParams);
}
}
~MklDnnMatMulFwdPrimitive() {}
// Inner-product forward execute with bias:
// - src_data: input data buffer of src
// - weight_data: input data buffer of weight
// - bias_data: input data buffer of bias
// - dst_data: output data buffer of dst
void Execute(const Tinput* src_data, const Tweight* weight_data,
const Tbias* bias_data, Toutput* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)));
context_.weight_mem->set_data_handle(
static_cast<void*>(const_cast<Tweight*>(weight_data)));
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<Tbias*>(bias_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);
// After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
context_.weight_mem->set_data_handle(DummyData);
context_.bias_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetweightMemoryFormat() const { return context_.weight_fmt; }
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
GetPrimitiveDesc() const {
return context_.fwd_pd;
}
private:
// Primitive reuse context for inner-product Fwd op
struct MklDnnMatMulFwdContext {
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format weight_fmt;
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> weight_mem;
std::shared_ptr<mkldnn::memory> bias_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// Descriptor and primitive-descriptor for forward inner-product
std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;
// Memory descriptors
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> weight_md;
std::shared_ptr<mkldnn::memory::desc> bias_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// Inner-product primitive
std::shared_ptr<mkldnn::primitive> matmul_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
MklDnnMatMulFwdContext()
: src_fmt(memory::format::any),
weight_fmt(memory::format::any),
src_mem(nullptr),
weight_mem(nullptr),
bias_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
fwd_pd(nullptr),
src_md(nullptr),
weight_md(nullptr),
bias_md(nullptr),
matmul_fwd(nullptr),
fwd_stream(nullptr) {}
};
void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
// Create memory descriptors for inner-product data with no specified format
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
MklDnnType<Tinput>(),
memory::format::any));
context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
MklDnnType<Tweight>(),
memory::format::any));
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
MklDnnType<Toutput>(),
memory::format::any));
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(),
memory::format::any));
// Create an inner-product
context_.fwd_desc.reset(new inner_product_forward::desc(
prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
*context_.bias_md, *context_.dst_md));
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
// Check if there is any fusion as post-ops
auto const& post_op_params = matmul_fwd_params.post_op_params;
mkldnn::primitive_attr post_ops_attr;
mkldnn::post_ops post_ops;
if (!post_op_params.empty()) {
for (auto const& post_op_param : post_op_params) {
if (post_op_param.name == "relu") {
DCHECK_EQ(post_op_param.param.size(), 3);
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
op_beta);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
std::vector<float> scales;
scales.push_back(post_op_param.param[0]);
post_ops_attr.set_output_scales(0, scales);
} else {
DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "output_scale"));
}
}
post_ops_attr.set_post_ops(post_ops);
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, post_ops_attr, cpu_engine_));
} else {
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
}
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
context_.weight_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// Create memory primitive based on dummy data
context_.src_mem.reset(
new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
context_.weight_mem.reset(
new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(),
memory::format::x},
cpu_engine_},
DummyData));
// Create inner-product primitive
context_.matmul_fwd.reset(new inner_product_forward(
*context_.fwd_pd, *context_.src_mem, *context_.weight_mem,
*context_.bias_mem, *context_.dst_mem));
context_.fwd_primitives.push_back(*context_.matmul_fwd);
return;
}
struct MklDnnMatMulFwdContext context_;
engine cpu_engine_;
};
template <typename T, typename Tinput, typename Tweight, typename Tbias,
typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
nullptr;
if (do_not_cache) {
// Always create new primitive
matmul_fwd =
new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
mkldnn_matmul_fwd_dims);
} else {
// try to find a suitable one in pool
matmul_fwd = dynamic_cast<
MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
Toutput>::GetInstance()
.GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
if (matmul_fwd == nullptr) {
matmul_fwd =
new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
mkldnn_matmul_fwd_dims);
MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
Toutput>::GetInstance()
.SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
}
}
return matmul_fwd;
}
private:
MklDnnMatMulFwdPrimitiveFactory() {}
~MklDnnMatMulFwdPrimitiveFactory() {}
static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
static MklDnnMatMulFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
string prefix = "matmul_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
// Generate keys for post-ops
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
if (post_op_param.name == "relu") {
DCHECK_EQ(post_op_param.param.size(), 3);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
key_creator.AddAsKey(post_op_param.param[1]);
key_creator.AddAsKey(post_op_param.param[2]);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
} else {
return string("not_a_key");
}
}
return key_creator.GetKey();
}
MklPrimitive* GetMklDnnMatMulFwd(
const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
string key = CreateKey(mkldnn_matmul_fwd_dims);
return this->GetOp(key);
}
void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
MklPrimitive* op) {
string key = CreateKey(mkldnn_matmul_fwd_dims);
this->SetOp(key, op);
}
};
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename Tinput, typename Tweight, typename Tbias,
typename Toutput>
class MklDnnQuantizedMatMulOp : public OpKernel {
class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Toutput> {
public:
virtual ~MklDnnQuantizedMatMulOp() {
if (this->input_bias_ != nullptr) {
@ -430,7 +133,7 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
}
explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context)
: OpKernel(context) {
: MklDnnMatMulOpBase<Toutput>(context) {
string mode_string;
OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string));
if (mode_string == "MIN_FIRST") {
@ -447,19 +150,20 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
void Compute(OpKernelContext* context) override {
try {
// Input tensors
const Tensor& src_tensor = MklGetInput(context, kInputIndexSrc);
const Tensor& weight_tensor = MklGetInput(context, kInputIndexWeight);
const Tensor& bias_tensor = MklGetInput(context, kInputIndexBias);
const Tensor& src_tensor = MklGetInput(context, this->kInputIndexSrc);
const Tensor& weight_tensor =
MklGetInput(context, this->kInputIndexWeight);
const Tensor& bias_tensor = MklGetInput(context, this->kInputIndexBias);
MklDnnShape src_mkl_shape, weight_mkl_shape;
GetMklShape(context, kInputIndexSrc, &src_mkl_shape);
GetMklShape(context, kInputIndexWeight, &weight_mkl_shape);
GetMklShape(context, this->kInputIndexSrc, &src_mkl_shape);
GetMklShape(context, this->kInputIndexWeight, &weight_mkl_shape);
OP_REQUIRES(context, !weight_mkl_shape.IsMklTensor(),
errors::InvalidArgument("Weight should not be in "
"MKL Layout"));
MklDnnData<Tinput> src(&cpu_engine_);
MklDnnData<Tweight> weight(&cpu_engine_);
MklDnnData<Tinput> src(&(this->cpu_engine_));
MklDnnData<Tweight> weight(&(this->cpu_engine_));
memory::dims src_dims, weight_dims;
memory::dims dst_dims_tf_order, dst_dims_mkl_order;
@ -524,8 +228,8 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
// Allocate output Tensor.
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
input_output_fmt, &dst_tensor);
this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
input_output_fmt, &dst_tensor);
Toutput* dst_data =
reinterpret_cast<Toutput*>(dst_tensor->flat<Toutput>().data());
@ -724,32 +428,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
}
}
// Allocate output tensor.
virtual void AllocateOutputTensor(
OpKernelContext* context,
const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
DCHECK(output_tensor);
auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc();
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<Toutput>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format, true);
TensorShape output_tf_shape;
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
// Allocate Output Tensor
AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
output_tf_shape, output_mkl_shape);
}
engine cpu_engine_ = engine(engine::cpu, 0);
private:
memory* input_bias_ = nullptr;
memory* scaled_bias_ = nullptr;
@ -757,11 +435,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
// Buffer to save the compensated bias
float* comp_bias_ = nullptr;
const int kInputIndexSrc = 0;
const int kInputIndexWeight = 1;
const int kInputIndexBias = 2;
const int kOutputIndexDst = 0;
int mode_;
};

View File

@ -60,6 +60,32 @@ REGISTER_OP("_MklFusedConv2D")
is expected to create these operators.
)doc");
REGISTER_OP("_MklFusedMatMul")
.Input("a: T")
.Input("b: T")
.Input("args: num_args * T")
.Input("mkl_a: uint8")
.Input("mkl_b: uint8")
.Input("mkl_args: num_args * uint8")
.Output("product: T")
.Output("mkl_product: uint8")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
.Attr("T: {float}")
.Attr("num_args: int >= 0")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ----------- //
.Attr("epsilon: float = 0.0001")
// --------------------------------------------- //
.SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc(
MKL version of the FusedMatMul operator. Uses MKL-DNN APIs to implement MatMul
fused with BiasAdd.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("__MklDummyPadWithFusedConv2D")
.Input("input: T")
.Input("filter: T")

View File

@ -482,7 +482,7 @@ class MklDnnShape {
/// We use lazy evaluation and create it only when needed. Input format can
/// also be Blocked format.
inline void SetTfLayout(size_t dims, const memory::dims& sizes,
MKL_TENSOR_FORMAT format, bool is_2d = false) {
MKL_TENSOR_FORMAT format) {
DCHECK_EQ(dims, sizes.size())
<< "SetTfLayout: Number of dimensions does not"
"match with dimension array";
@ -492,7 +492,7 @@ class MklDnnShape {
}
data_.tf_data_format_ = format;
if (format != MKL_TENSOR_FORMAT_BLOCKED) {
if (is_2d) {
if (dims == 2) {
data_.map_[0] = MklDnnDims::Dim_N;
data_.map_[1] = MklDnnDims::Dim_C;
} else {