Merge pull request #31782 from Intel-tensorflow:tenglu/fuse_fp32_matmul
PiperOrigin-RevId: 270955757
Commit: 8e3f8ae473
@@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
    csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
    csinfo_.fused_conv2d = "_FusedConv2D";
    csinfo_.fused_matmul = "_FusedMatMul";
    csinfo_.identity = "Identity";
    csinfo_.leakyrelu = "LeakyRelu";
    csinfo_.leakyrelu_grad = "LeakyReluGrad";
@@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    csinfo_.mkl_depthwise_conv2d_grad_filter =
        "_MklDepthwiseConv2dNativeBackpropFilter";
    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
    csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
    csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
    csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
    csinfo_.pad = "Pad";
@@ -478,6 +480,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                      CopyAttrsFusedConv2D, FusedConv2DRewrite,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
                      CopyAttrsAll, FusedMatMulRewrite});

#ifndef ENABLE_MKLDNN_V1
    rinfo_.push_back({csinfo_.identity,
                      mkl_op_registry::GetMklOpName(csinfo_.identity),
@@ -905,6 +910,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
    string fused_batch_norm_v3;
    string fused_batch_norm_grad_v3;
    string fused_conv2d;
    string fused_matmul;
    string identity;
    string leakyrelu;
    string leakyrelu_grad;
@@ -924,6 +930,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
    string mkl_depthwise_conv2d_grad_input;
    string mkl_depthwise_conv2d_grad_filter;
    string mkl_fused_conv2d;
    string mkl_fused_matmul;
    string mkl_pad_with_conv2d;
    string mkl_pad_with_fused_conv2d;
    string mul;
@@ -1453,6 +1460,22 @@ rinfo_.push_back({csinfo_.tanh_grad,
    return false;
  }

  // Rewrite rule for _FusedMatMul.
  // @return - true if the first input is not transposed and there is exactly
  //           one fused post-op; false otherwise.
  static bool FusedMatMulRewrite(const Node* n) {
    bool trans_a;
    std::vector<string> fused_ops;

    // Do not rewrite when transpose_a is set, because the required reorder
    // hurts performance.
    TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
    // Do not rewrite with more than one post-op, because MKL-DNN does not
    // support it.
    TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));

    return (!trans_a) && (fused_ops.size() == 1);
  }

  // Check if we are performing pooling on depth or batch. If so, we do not
  // rewrite the MaxPool node to its Mkl version.
  // @return - true (if it is not a depth/batch wise pooling case);
@@ -3553,6 +3576,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
      n->type_string() != csinfo_.pad_with_fused_conv2d &&
      n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
      n->type_string() != csinfo_.fused_conv2d &&
      n->type_string() != csinfo_.fused_matmul &&
      !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
                                T)) {
    return nullptr;
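For context: each rinfo_ entry above pairs a TF op name with its Mkl counterpart, an attribute-copying helper, and a predicate such as FusedMatMulRewrite that gates the rewrite. A minimal sketch of how such a table can be consulted; RewriteEntry and FindRewrite are illustrative stand-ins, not the pass's actual internals:

#include <string>
#include <vector>

struct Node;  // opaque graph node, standing in for tensorflow::Node

struct RewriteEntry {
  std::string tf_op;                // e.g. "_FusedMatMul"
  std::string mkl_op;               // e.g. "_MklFusedMatMul"
  bool (*rewrite_ok)(const Node*);  // e.g. FusedMatMulRewrite
};

// Returns the first entry whose op name matches and whose predicate accepts
// the node; a null result means the node is left un-rewritten.
const RewriteEntry* FindRewrite(const std::vector<RewriteEntry>& table,
                                const std::string& op, const Node* n) {
  for (const auto& e : table) {
    if (e.tf_op == op && e.rewrite_ok(n)) return &e;
  }
  return nullptr;
}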
@@ -1874,6 +1874,53 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) {
      "D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E");
}

// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Postive) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: '_FusedMatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: false } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " attr { key: 'num_args' value { i: 1 } }"
      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
      " attr { key: 'epsilon' value { f: 0.001 }}"
      " input: ['A', 'B', 'C']}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['D', 'C']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_MklFusedMatMul);DMT/_0(Const);"
            "DMT/_1(Const);DMT/_2(Const);Z(Zeta)"
            "|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;"
            "A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;DMT/_0->D:3;"
            "DMT/_1->D:4;DMT/_2->D:5");
}

// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Negative) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: '_FusedMatMul'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " attr { key: 'transpose_a' value { b: true } }"
      " attr { key: 'transpose_b' value { b: false } }"
      " attr { key: 'num_args' value { i: 1 } }"
      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
      " attr { key: 'epsilon' value { f: 0.001 }}"
      " input: ['A', 'B', 'C']}"
      "node { name: 'Z' op: 'Zeta'"
      " attr {key: 'T' value { type: DT_FLOAT } }"
      " input: ['D', 'C']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);C(Input);D(_FusedMatMul);Z(Zeta)"
            "|A->D;B->D:1;C->D:2;C->Z:1;D->Z");
}

// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
// padding is VALID type
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
@@ -263,13 +263,7 @@ bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {

bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
  DCHECK(IsMatMul(*matmul)) << "Expected MatMul op";
#ifndef INTEL_MKL
  // Temporarily disable MatMul fusions when MKL is enabled.
  // TODO(Intel): re-enable MatMul fusions once they are supported by MKL-DNN.
  return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
#else
  return false;
#endif  // !INTEL_MKL
}

// Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
@@ -3825,7 +3825,11 @@ tf_kernel_library(

tf_mkl_kernel_library(
    name = "mkl_matmul_op",
-   srcs = ["mkl_matmul_op.cc"],
+   srcs = [
+       "mkl_matmul_op.cc",
+       "mkl_matmul_op_fused.cc",
+   ],
+   hdrs = ["mkl_matmul_ops_common.h"],
    deps = MATH_DEPS + mkl_deps(),
)

@@ -7546,6 +7550,7 @@ tf_mkl_kernel_library(
    name = "mkl_qmatmul_op",
    srcs = ["mkl_qmatmul_op.cc"],
    hdrs = [
        "mkl_matmul_ops_common.h",
        "mkl_quantized_conv_ops.h",
        "no_op.h",
    ],
@@ -7860,6 +7865,7 @@ tf_cc_test_mkl(
    ":conv_ops",
    ":image",
    ":mkl_conv_op",
    ":mkl_matmul_op",
    ":mkl_tfconv_op",
    ":ops_testutil",
    ":ops_util",
@@ -160,6 +160,33 @@ class CommonTestUtilities : public OpsTestBase {

    test::ExpectClose(conv_2d, fused_conv_2d, 1e-5);
  }

  static void VerifyFusedMatrixClose(int depth, int batch, int weight_count,
                                     const std::vector<string>& fused_ops,
                                     const FusedGraphRunner& run_default,
                                     const FusedGraphRunner& run_fused) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor input(dtype, {batch, depth});
    input.flat<T>() = input.flat<T>().setRandom();

    Tensor weight(dtype, {depth, weight_count});
    weight.flat<T>() = weight.flat<T>().setRandom();

    Tensor bias(dtype, {weight_count});
    bias.flat<T>() = bias.flat<T>().setRandom();

    Tensor output;
    Tensor fused_output;

    run_default(input, weight, bias, fused_ops, &output);
    run_fused(input, weight, bias, fused_ops, &fused_output);

    ASSERT_EQ(output.dtype(), fused_output.dtype());
    ASSERT_EQ(output.shape(), fused_output.shape());

    test::ExpectClose(output, fused_output, 1e-5);
  }
};

// Testing MKL's fused convolution ops
@@ -586,6 +613,97 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
  Run<float>(DT_FLOAT, image, filter, expected, true);
}

// Testing fusion of MatMul and BiasAdd
template <typename T>
class MklFusedMatMulOpTest : public OpsTestBase {
 protected:
  void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
                         const int kOutputChannel) {
    const FusedGraphRunner run_default =
        [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
               const std::vector<string>& fused_ops, Tensor* output) {
          auto root = tensorflow::Scope::NewRootScope();
          auto input_op =
              ops::Const(root.WithOpName("input"), Input::Initializer(input));
          Output next_op = ops::MatMul(root.WithOpName("matmul"), input_op,
                                       ops::Const(root.WithOpName("weight"),
                                                  Input::Initializer(weight)));

          string last_op = "";
          if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
              fused_ops.end()) {
            last_op = "with_bias";
            next_op = ops::BiasAdd(
                root.WithOpName(last_op), next_op,
                ops::Const(root.WithOpName("bias"), Input::Initializer(bias)));
          }

          CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
        };

    const FusedGraphRunner run_fused =
        [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
               const std::vector<string>& fused_ops, Tensor* output) {
          DataType dtype = DataTypeToEnum<T>::v();
          const int num_args = fused_ops.size();

          TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
                           .Input(FakeInput(dtype))
                           .Input(FakeInput(dtype))
                           .Input(FakeInput(num_args, dtype))
                           .Input(FakeInput(DT_UINT8))
                           .Input(FakeInput(DT_UINT8))
                           .Input(FakeInput(num_args, DT_UINT8))
                           .Attr("T", dtype)
                           .Attr("transpose_a", false)
                           .Attr("transpose_b", false)
                           .Attr("num_args", num_args)
                           .Attr("fused_ops", fused_ops)
                           .Attr("epsilon", 0.0001)
                           .Attr("_kernel", "MklLayoutDependentOp")
                           .Finalize(node_def()));

          TF_EXPECT_OK(InitOp());

          AddInputFromArray<T>(input.shape(), input.flat<T>());
          AddInputFromArray<T>(weight.shape(), weight.flat<T>());
          AddInputFromArray<T>(bias.shape(), bias.flat<T>());
          // Add MKL meta inputs for input, weight, and bias.
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);

          TF_ASSERT_OK(RunOpKernel());

          const Tensor& output_tensor = *GetOutput(0);
          const Tensor& output_meta_tensor = *GetOutput(1);
          CommonTestUtilities<T> test_util;
          test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
                                      output);
        };

    CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
                                                   kOutputChannel, {"BiasAdd"},
                                                   run_default, run_fused);
  }
};

TYPED_TEST_CASE_P(MklFusedMatMulOpTest);

TYPED_TEST_P(MklFusedMatMulOpTest, BasicTest) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel);
}

REGISTER_TYPED_TEST_CASE_P(MklFusedMatMulOpTest, BasicTest);

using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedMatMulOpTest,
                              MklFusedMatMulDataTypes);

class BiasCacheTest : public OpsTestBase {
 public:
  template <typename T>
@@ -715,6 +833,7 @@ TEST_F(BiasCacheTest, Conv2DBiasCacheTest) {
  Run<float>(DT_QUINT8, image, filter, bias, min_input, max_input, min_filter,
             max_filter, min_output, max_output, expected, true);
}

// Testing fusion of pad and fusedconv2d
template <typename T>
class MklPadWithFusedConv2DOpTest : public OpsTestBase {
tensorflow/core/kernels/mkl_matmul_op_fused.cc (new file, 193 lines)
@@ -0,0 +1,193 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL-DNN InnerProduct for acceleration of TF Matrix-Matrix
// Multiplication (MatMul) with bias (BiasAdd) operations.
#ifdef INTEL_MKL

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Fused operation: MatMul + BiasAdd.
template <typename Device, typename T>
class MklFusedMatMulOp : public MklDnnMatMulOpBase<T> {
 public:
  explicit MklFusedMatMulOp(OpKernelConstruction* ctx)
      : MklDnnMatMulOpBase<T>(ctx) {
    std::vector<string> fused_ops;

    OP_REQUIRES_OK(ctx, ctx->GetAttr("fused_ops", &fused_ops));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    OP_REQUIRES(ctx, fused_ops.size() == 1,
                errors::InvalidArgument(
                    "MklFusedMatMul must have only one argument: bias."));
    OP_REQUIRES(
        ctx, transpose_a_ == false,
        errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
  }

  void Compute(OpKernelContext* ctx) override {
    // FusedMatMul has 3 inputs: src, weights, bias
    const Tensor& src_tensor = ctx->input(this->kInputIndexSrc);
    const Tensor& weight_tensor = ctx->input(this->kInputIndexWeight);
    const Tensor& bias_tensor = MklGetInput(ctx, this->kInputIndexBias);

    MklDnnShape src_mkl_shape;
    MklDnnShape weight_mkl_shape;
    GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape);
    GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape);

    // Get shapes of input tensors.
    auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
                                                    : src_tensor.shape();
    auto weight_tf_shape = weight_mkl_shape.IsMklTensor()
                               ? weight_mkl_shape.GetTfShape()
                               : weight_tensor.shape();

    // Check the constraints on the input matrices and bias.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(src_tf_shape),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(weight_tf_shape),
                errors::InvalidArgument("In[1] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_tensor.shape()),
                errors::InvalidArgument("Biases must be 1D"));

    // Expression: [batch, k] * [k, channel] + [channel] = [batch, channel]
    //
    // Get the dimension sizes of each matrix; dim_pair[] gives the position
    // of k in each input, under the constraint that k must be the same in
    // both inputs.
    const int dim_pair[] = {1, transpose_b_ ? 1 : 0};
    const int batch = src_tf_shape.dim_size(1 - dim_pair[0]);
    const int k = src_tf_shape.dim_size(dim_pair[0]);
    const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);

    OP_REQUIRES(
        ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
            ", In[1]: ", weight_tf_shape.DebugString()));
    OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel size: ",
                    bias_tensor.shape().DebugString(), " vs. ", channel));

    // For inputs s[batch, k], w[k, channel] and b[channel], the primitive
    // dims should be described like this:
    //   s[batch, k] * w^T[channel, k] + b[channel] = dst[batch, channel]
    //   [n, ic] * [oc, ic] + [oc] = [n, oc]
    memory::dims src_dims = memory::dims({batch, k});
    // Reverse the weights dims from [k, channel] to [channel, k].
    memory::dims weight_dims = memory::dims({channel, k});
    memory::dims bias_dims = memory::dims({channel});
    memory::dims dst_dims = memory::dims({batch, channel});
    memory::format weight_format =
        transpose_b_ ? memory::format::oi : memory::format::io;

    MklDnnMatMulFwdParams matmul_params(src_dims, weight_dims, bias_dims,
                                        dst_dims, weight_format);
    MklDnnMatMulFwdPrimitive<T, T, T, T, T>* matmul_prim =
        MklDnnMatMulFwdPrimitiveFactory<T, T, T, T, T>::Get(matmul_params, 0);

    // Allocate output tensor.
    Tensor* dst_tensor = nullptr;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
        matmul_prim->GetPrimitiveDesc();

    if (src_mkl_shape.IsMklTensor() && weight_mkl_shape.IsMklTensor()) {
      this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, memory::format::nc,
                                 &dst_tensor);
    } else {
      TensorShape dst_tensor_shape({batch, channel});
      MklDnnShape dst_mkl_shape;
      dst_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
                                dst_mkl_shape);
    }

    // If there is nothing to compute, just return.
    if (batch == 0 || channel == 0) {
      return;
    }

    try {
      // Prepare the inputs and output for the primitive.
      T* src_data = const_cast<T*>(src_tensor.flat<T>().data());
      T* weight_data = const_cast<T*>(weight_tensor.flat<T>().data());
      T* bias_data = const_cast<T*>(bias_tensor.flat<T>().data());
      T* dst_data = const_cast<T*>(dst_tensor->flat<T>().data());

      // If any input is in MKL format, reorder it if necessary.
      MklDnnData<T> src_mkl(&(this->cpu_engine_));
      MklDnnData<T> weight_mkl(&(this->cpu_engine_));

      if (src_mkl_shape.IsMklTensor()) {
        memory::desc input_md = src_mkl_shape.GetMklLayout();

        if (input_md.data.format != memory::format::nc) {
          src_mkl.SetUsrMem(input_md, src_data);
          src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_primitive_desc());
          src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
        }
      }

      if (weight_mkl_shape.IsMklTensor()) {
        memory::desc input_md = weight_mkl_shape.GetMklLayout();

        if (input_md.data.format != weight_format) {
          weight_mkl.SetUsrMem(input_md, weight_data);
          weight_mkl.CheckReorderToOpMem(
              matmul_pd.get()->weights_primitive_desc());
          weight_data =
              reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
        }
      }

      matmul_prim->Execute(src_data, weight_data, bias_data, dst_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};

// Register mkl kernels for supported operations and types.
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklFusedMatMul")                                  \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklFusedMatMulOp<CPUDevice, type>);
TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);

}  // namespace tensorflow

#endif  // INTEL_MKL
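As a concrete check of the shape bookkeeping in Compute above: for non-transposed s[2, 3] times w[3, 4] plus b[4], the kernel derives batch = 2, k = 3, channel = 4, and then hands MKL-DNN the weights as [channel, k]. A standalone sketch of that index arithmetic (plain C++, no MKL-DNN types; DeriveDims is a hypothetical helper, not part of the commit):

#include <array>
#include <cassert>

// Mirrors the dim_pair logic in MklFusedMatMulOp::Compute for
// transpose_a == false; transpose_b selects where k lives in the weights.
struct MatMulDims { int batch, k, channel; };

MatMulDims DeriveDims(std::array<int, 2> src, std::array<int, 2> weight,
                      bool transpose_b) {
  const int dim_pair[] = {1, transpose_b ? 1 : 0};  // location of k
  MatMulDims d;
  d.batch = src[1 - dim_pair[0]];
  d.k = src[dim_pair[0]];
  d.channel = weight[1 - dim_pair[1]];
  assert(d.k == weight[dim_pair[1]] && "inner dimensions must match");
  return d;
}

int main() {
  MatMulDims d = DeriveDims({2, 3}, {3, 4}, /*transpose_b=*/false);
  assert(d.batch == 2 && d.k == 3 && d.channel == 4);
  return 0;
}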
tensorflow/core/kernels/mkl_matmul_ops_common.h (new file, 372 lines)
@@ -0,0 +1,372 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::format weight_fmt;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
                        memory::dims bias_dims, memory::dims dst_dims,
                        memory::format weight_fmt = memory::format::any)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        weight_fmt(weight_fmt) {}
};

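To make the post-op plumbing concrete: a caller fusing, say, ReLU after the inner product would append a named PostOpParam before requesting a primitive. A hedged sketch against the struct above; the dims and the helper function are illustrative, and the {scale, alpha, beta} triple is what Setup() below expects for "relu":

#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"  // the header above

namespace tensorflow {
// Hypothetical caller-side setup, not part of the commit.
void ExampleBuildParams() {
  MklDnnMatMulFwdParams params(
      /*src_dims=*/{64, 128},      // [batch, k]
      /*weight_dims=*/{256, 128},  // [channel, k], already weight-transposed
      /*bias_dims=*/{256},
      /*dst_dims=*/{64, 256});
  // Setup() consumes exactly three floats for "relu": {scale, alpha, beta}.
  params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});
}
}  // namespace tensorflow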
// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    // Create the matmul primitive.
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  //   - src_data: input data buffer of src
  //   - weight_data: input data buffer of weight
  //   - bias_data: input data buffer of bias
  //   - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    context_.fwd_stream->submit(context_.fwd_primitives);

    // After execution, set the data handles back.
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
  memory::format GetweightMemoryFormat() const { return context_.weight_fmt; }
  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // Expected memory format for this primitive instance
    memory::format src_fmt;
    memory::format weight_fmt;

    // MKL-DNN memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Inner-product primitive
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    MklDnnMatMulFwdContext()
        : src_fmt(memory::format::any),
          weight_fmt(memory::format::any),
          src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data with no specified
    // format.
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           memory::format::any));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              matmul_fwd_params.weight_fmt));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           memory::format::any));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format::any));
    // Create an inner-product descriptor.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));
    // Check if there are any fusions specified as post-ops.
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
                                  op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Store the expected memory formats.
    context_.src_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

    context_.weight_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);

    // Create memory primitives based on dummy data.
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
    context_.weight_mem.reset(
        new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
    context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims},
                                         MklDnnType<Tbias>(),
                                         memory::format::x},
                                        cpu_engine_},
                                       DummyData));

    // Create the inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(
        *context_.fwd_pd, *context_.src_mem, *context_.weight_mem,
        *context_.bias_mem, *context_.dst_mem));

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;
  engine cpu_engine_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable primitive in the pool.
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);

    // Generate keys for post-ops.
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};

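Putting the primitive and factory together, the steady-state call path a kernel follows is: build params, ask the factory for a (possibly cached) primitive, then Execute with raw buffers. A hedged sketch assuming float for all five type parameters, mirroring how mkl_matmul_op_fused.cc drives it; ExampleRunFusedMatMul and its dims are illustrative, not part of the commit:

#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"  // the header above

namespace tensorflow {
void ExampleRunFusedMatMul(const float* src, const float* weight,
                           const float* bias, float* dst) {
  // Same illustrative shapes as the params example earlier.
  MklDnnMatMulFwdParams params({64, 128}, {256, 128}, {256}, {64, 256});

  // do_not_cache == false: reuse a pooled primitive keyed on dims, dtypes,
  // and post-ops (see CreateKey above).
  MklDnnMatMulFwdPrimitive<float, float, float, float, float>* prim =
      MklDnnMatMulFwdPrimitiveFactory<float, float, float, float, float>::Get(
          params, /*do_not_cache=*/false);

  // Execute binds the buffers, submits the stream, then unbinds them so the
  // cached primitive holds no dangling pointers between calls.
  prim->Execute(src, weight, bias, dst);
}
}  // namespace tensorflow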
template <class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      memory::format output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    // Allocate Output Tensor
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  engine cpu_engine_ = engine(engine::cpu, 0);

 protected:
  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_
@@ -90,19 +90,12 @@ limitations under the License.
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
#ifdef INTEL_MKL

#include "mkldnn.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::inner_product_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace {
enum {
@@ -113,299 +106,9 @@ enum {

namespace tensorflow {

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  string dtypes = string("");
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
                        memory::dims bias_dims, memory::dims dst_dims)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims) {}
};

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  // Inner-product forward execute with bias:
  //   - src_data: input data buffer of src
  //   - weight_data: input data buffer of weight
  //   - bias_data: input data buffer of bias
  //   - dst_data: output data buffer of dst
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    context_.fwd_stream->submit(context_.fwd_primitives);

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
  memory::format GetweightMemoryFormat() const { return context_.weight_fmt; }
  std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // Expected memory format for this primitive instance
    memory::format src_fmt;
    memory::format weight_fmt;

    // MKL-DNN memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weight_mem;
    std::shared_ptr<mkldnn::memory> bias_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Descriptor and primitive-descriptor for forward inner-product
    std::shared_ptr<mkldnn::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> weight_md;
    std::shared_ptr<mkldnn::memory::desc> bias_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Inner-product primitive
    std::shared_ptr<mkldnn::primitive> matmul_fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    MklDnnMatMulFwdContext()
        : src_fmt(memory::format::any),
          weight_fmt(memory::format::any),
          src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          matmul_fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data with no specified format
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           memory::format::any));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              memory::format::any));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           memory::format::any));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format::any));
    // Create an inner-product
    context_.fwd_desc.reset(new inner_product_forward::desc(
        prop_kind::forward_inference, *context_.src_md, *context_.weight_md,
        *context_.bias_md, *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));
    // Check if there is any fusion as post-ops
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    mkldnn::primitive_attr post_ops_attr;
    mkldnn::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha,
                                  op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, cpu_engine_));
    }

    // Store the expected memory format
    context_.src_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

    context_.weight_fmt = static_cast<mkldnn::memory::format>(
        context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
    context_.weight_mem.reset(
        new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
    context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims},
                                         MklDnnType<Tbias>(),
                                         memory::format::x},
                                        cpu_engine_},
                                       DummyData));

    // Create inner-product primitive
    context_.matmul_fwd.reset(new inner_product_forward(
        *context_.fwd_pd, *context_.src_mem, *context_.weight_mem,
        *context_.bias_mem, *context_.dst_mem));

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;
  engine cpu_engine_;
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // try to find a suitable one in pool
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);

    // Generate keys for post-ops
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
-class MklDnnQuantizedMatMulOp : public OpKernel {
+class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Toutput> {
 public:
  virtual ~MklDnnQuantizedMatMulOp() {
    if (this->input_bias_ != nullptr) {
@@ -430,7 +133,7 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
  }

  explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context)
-     : OpKernel(context) {
+     : MklDnnMatMulOpBase<Toutput>(context) {
    string mode_string;
    OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string));
    if (mode_string == "MIN_FIRST") {
@@ -447,19 +150,20 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors
-     const Tensor& src_tensor = MklGetInput(context, kInputIndexSrc);
-     const Tensor& weight_tensor = MklGetInput(context, kInputIndexWeight);
-     const Tensor& bias_tensor = MklGetInput(context, kInputIndexBias);
+     const Tensor& src_tensor = MklGetInput(context, this->kInputIndexSrc);
+     const Tensor& weight_tensor =
+         MklGetInput(context, this->kInputIndexWeight);
+     const Tensor& bias_tensor = MklGetInput(context, this->kInputIndexBias);

      MklDnnShape src_mkl_shape, weight_mkl_shape;
-     GetMklShape(context, kInputIndexSrc, &src_mkl_shape);
-     GetMklShape(context, kInputIndexWeight, &weight_mkl_shape);
+     GetMklShape(context, this->kInputIndexSrc, &src_mkl_shape);
+     GetMklShape(context, this->kInputIndexWeight, &weight_mkl_shape);
      OP_REQUIRES(context, !weight_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Weight should not be in "
                                          "MKL Layout"));

-     MklDnnData<Tinput> src(&cpu_engine_);
-     MklDnnData<Tweight> weight(&cpu_engine_);
+     MklDnnData<Tinput> src(&(this->cpu_engine_));
+     MklDnnData<Tweight> weight(&(this->cpu_engine_));

      memory::dims src_dims, weight_dims;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;
@@ -524,8 +228,8 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
      // Allocate output Tensor.
      std::shared_ptr<mkldnn::inner_product_forward::primitive_desc>
          matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
-     AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
-                          input_output_fmt, &dst_tensor);
+     this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
+                                input_output_fmt, &dst_tensor);

      Toutput* dst_data =
          reinterpret_cast<Toutput*>(dst_tensor->flat<Toutput>().data());
@@ -724,32 +428,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
    }
  }

-  // Allocate output tensor.
-  virtual void AllocateOutputTensor(
-      OpKernelContext* context,
-      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
-      const memory::dims& output_dims_mkl_order,
-      memory::format output_tf_format, Tensor** output_tensor) {
-    DCHECK(output_tensor);
-    auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc();
-
-    MklDnnShape output_mkl_shape;
-    output_mkl_shape.SetMklTensor(true);
-    output_mkl_shape.SetMklLayout(&dst_pd);
-    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
-    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
-                                 output_dims_mkl_order, output_tf_format, true);
-
-    TensorShape output_tf_shape;
-    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
-
-    // Allocate Output Tensor
-    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
-                              output_tf_shape, output_mkl_shape);
-  }
-
-  engine cpu_engine_ = engine(engine::cpu, 0);
-
 private:
  memory* input_bias_ = nullptr;
  memory* scaled_bias_ = nullptr;
@@ -757,11 +435,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel {
  // Buffer to save the compensated bias
  float* comp_bias_ = nullptr;

-  const int kInputIndexSrc = 0;
-  const int kInputIndexWeight = 1;
-  const int kInputIndexBias = 2;
-  const int kOutputIndexDst = 0;
-
  int mode_;
};
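The churn above is a pure refactor: the input-index constants and AllocateOutputTensor move into the templated base MklDnnMatMulOpBase, and the derived class must now reach them through this-> because they live in a template-dependent base. A minimal, self-contained illustration of why that qualification is required:

#include <cassert>

// Why `this->kInputIndexSrc` rather than plain `kInputIndexSrc`: members of a
// template-dependent base are not found by unqualified name lookup inside a
// derived class template.
template <typename T>
struct Base {
  static constexpr int kInputIndexSrc = 0;
};

template <typename T>
struct Derived : Base<T> {
  int First() {
    // return kInputIndexSrc;       // error: name not found without this->
    return this->kInputIndexSrc;    // OK: lookup deferred to instantiation
  }
};

int main() {
  assert(Derived<float>().First() == 0);
  return 0;
}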
@@ -60,6 +60,32 @@ REGISTER_OP("_MklFusedConv2D")
is expected to create these operators.
)doc");

REGISTER_OP("_MklFusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Input("mkl_a: uint8")
    .Input("mkl_b: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("product: T")
    .Output("mkl_product: uint8")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ----------- //
    .Attr("epsilon: float = 0.0001")
    // --------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
MKL version of the FusedMatMul operator. Uses MKL-DNN APIs to implement the
MatMul operator.

NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("__MklDummyPadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
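The signature follows the Mkl layout-dependent convention: every data input (a, b, each of args) is paired with a trailing uint8 metadata input carrying its MklDnnShape, which is why the positive rewrite test earlier wires DMT/_0..2 into input ports 3..5. A small sketch of that index arithmetic (MetaInputIndex is a hypothetical helper, not part of the commit):

#include <cassert>

// For an Mkl layout-dependent op with contiguous input ordering, the metadata
// tensor for data input i sits at i + num_data_inputs. For _MklFusedMatMul
// with num_args == 1 there are 3 data inputs (a, b, bias) and 3 meta inputs.
int MetaInputIndex(int data_index, int num_data_inputs) {
  assert(data_index >= 0 && data_index < num_data_inputs);
  return data_index + num_data_inputs;
}

int main() {
  assert(MetaInputIndex(0, 3) == 3);  // mkl_a follows a, b, bias
  assert(MetaInputIndex(2, 3) == 5);  // bias metadata is the last input
  return 0;
}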
@@ -482,7 +482,7 @@ class MklDnnShape {
  /// We use lazy evaluation and create it only when needed. Input format can
  /// also be Blocked format.
  inline void SetTfLayout(size_t dims, const memory::dims& sizes,
-                         MKL_TENSOR_FORMAT format, bool is_2d = false) {
+                         MKL_TENSOR_FORMAT format) {
    DCHECK_EQ(dims, sizes.size())
        << "SetTfLayout: Number of dimensions does not"
           "match with dimension array";
@@ -492,7 +492,7 @@ class MklDnnShape {
    }
    data_.tf_data_format_ = format;
    if (format != MKL_TENSOR_FORMAT_BLOCKED) {
-     if (is_2d) {
+     if (dims == 2) {
        data_.map_[0] = MklDnnDims::Dim_N;
        data_.map_[1] = MklDnnDims::Dim_C;
      } else {