Merge pull request #45358 from Intel-tensorflow:yanzhang/matmul_biasadd_add_fusion

PiperOrigin-RevId: 347408532
Change-Id: I6e8b12dfef056f095449fd70ce387b79d3e8b4d7
This commit is contained in:
TensorFlower Gardener 2020-12-14 10:05:53 -08:00
commit cd052fa5f0
5 changed files with 292 additions and 97 deletions

View File

@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
} }
} }
} }
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({4, 32});
auto input_shape_add = ops::Placeholder::Shape({4, 8});
auto filter_shape = ops::Placeholder::Shape({32, 8});
auto bias_shape = ops::Placeholder::Shape({8});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_add =
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
auto fetch = s.WithOpName("fetch");
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
ops::Identity(fetch, add);
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape.shape_.dim_sizes()));
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape_add.shape_.dim_sizes()));
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape.shape_.dim_sizes()));
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(bias_shape.shape_.dim_sizes()));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_tensor},
{"filter", filter_tensor},
{"bias", bias_tensor},
{"input_add", input_add_tensor}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
auto fetch_node_name = "add";
if (node.name() == fetch_node_name) {
EXPECT_EQ("_FusedMatMul", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
EXPECT_EQ(2, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
EXPECT_EQ("input_add", node.input(3));
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
found++;
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
}
#endif // ENABLE_MKLDNN_V1 #endif // ENABLE_MKLDNN_V1
} // namespace grappler } // namespace grappler

View File

@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
const NodeDef& contraction = graph->node(matched.contraction); const NodeDef& contraction = graph->node(matched.contraction);
const NodeDef& bias_add = graph->node(matched.bias_add); const NodeDef& bias_add = graph->node(matched.bias_add);
// MKL version only support fusion for Conv2D // MKL version only support fusion for Conv2D and MatMul
DCHECK(IsConv2D(contraction)); DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
NodeDef fused_conv2d; NodeDef contraction_node;
const NodeDef& add = graph->node(matched.add); const NodeDef& add = graph->node(matched.add);
fused_conv2d.set_name(add.name()); contraction_node.set_name(add.name());
fused_conv2d.set_op(kFusedConv2D); contraction_node.set_device(contraction.device());
fused_conv2d.set_device(contraction.device()); contraction_node.add_input(
fused_conv2d.add_input(contraction.input(0)); // 0: input contraction.input(0)); // 0: input(conv) / a (matmul)
fused_conv2d.add_input(contraction.input(1)); // 1: filter contraction_node.add_input(
fused_conv2d.add_input(bias_add.input(1)); // 2: bias contraction.input(1)); // 1: filter(conv) / b (matmul)
contraction_node.add_input(bias_add.input(1)); // 2: bias
// Add OP has two inputs, one is conv+bias pattern matched previously, // Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
// the other input to add is fused here. // previously, the other input to add is fused here.
fused_conv2d.add_input(add.input(1 - matched.port_id)); contraction_node.add_input(add.input(1 - matched.port_id));
CopyConv2DAttributes(contraction, &fused_conv2d); if (IsConv2D(contraction)) {
SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2); contraction_node.set_op(kFusedConv2D);
CopyConv2DAttributes(contraction, &contraction_node);
} else if (IsMatMul(contraction)) {
contraction_node.set_op(kFusedMatMul);
CopyMatMulAttributes(contraction, &contraction_node);
}
SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status; Status status;
mutation->AddNode(std::move(fused_conv2d), &status); mutation->AddNode(std::move(contraction_node), &status);
TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(status);
TF_RETURN_IF_ERROR(mutation->Apply()); TF_RETURN_IF_ERROR(mutation->Apply());
@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
} }
#ifdef INTEL_MKL #ifdef INTEL_MKL
bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) { bool IsConv2DOrMatMul(const NodeDef& node) {
return IsConv2D(node) || IsMatMul(node);
}
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
const auto* node_view = ctx.graph_view.GetNode(node_index); const auto* node_view = ctx.graph_view.GetNode(node_index);
const auto* node_def = node_view->node(); const auto* node_def = node_view->node();
// Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion. // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
// MatMul + Add or MatMul + BiasAdd + Add fusion.
auto is_supported_add_input = [](const auto* node_view) -> bool { auto is_supported_add_input = [](const auto* node_view) -> bool {
if (IsConv2D(*node_view->node())) return true; // Currently only support Conv2D and MatMul
if (IsConv2DOrMatMul(*node_view->node())) return true;
if (IsBiasAdd(*node_view->node())) { if (IsBiasAdd(*node_view->node())) {
if (node_view->NumRegularFanins() < 2) return false; if (node_view->NumRegularFanins() < 2) return false;
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0); const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1); const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
return IsConv2D(*bias_add_fanin_0.node_view()->node()) || return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
IsConv2D(*bias_add_fanin_1.node_view()->node()); IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
} }
return false; return false;
}; };
@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
#ifdef INTEL_MKL #ifdef INTEL_MKL
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() || return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
IsConv2DWithAdd(ctx, node_index); IsContractionWithAdd(ctx, node_index);
#else #else
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() || return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
is_batch_norm_fusion_candidate(); is_batch_norm_fusion_candidate();

View File

@ -955,6 +955,71 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
// Testing fusion of MatMul and BiasAdd // Testing fusion of MatMul and BiasAdd
template <typename T> template <typename T>
class MklFusedMatMulOpTest : public OpsTestBase { class MklFusedMatMulOpTest : public OpsTestBase {
private:
void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
const std::vector<Tensor>& args,
const std::vector<string>& fused_ops,
Tensor* output) {
DataType dtype = DataTypeToEnum<T>::v();
const int num_args = args.size();
if (!NativeFormatEnabled()) {
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
} else {
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklNameChangeOp")
.Finalize(node_def()));
}
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(input.shape(), input.flat<T>());
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
for (const Tensor& arg : args)
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
if (!NativeFormatEnabled()) {
// Add MKL meta input for input, filter and bias.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_tensor = *GetOutput(0);
if (!NativeFormatEnabled()) {
const Tensor& output_meta_tensor = *GetOutput(1);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
output);
} else {
*output = output_tensor;
}
}
protected: protected:
void VerifyFusedMatMul(const int kBatch, const int kInputChannel, void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
const int kOutputChannel, const int kOutputChannel,
@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase {
next_op = ops::Tanh(root.WithOpName(last_op), next_op); next_op = ops::Tanh(root.WithOpName(last_op), next_op);
} }
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
fused_ops.end()) {
last_op = "with_add";
next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output); CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
}; };
const FusedGraphRunner run_fused = const FusedGraphRunner run_fused =
[this](const Tensor& input, const Tensor& weight, const Tensor& bias, [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
const std::vector<string>& fused_ops, Tensor* output) { const std::vector<string>& fused_ops, Tensor* output) {
DataType dtype = DataTypeToEnum<T>::v(); std::vector<Tensor> fused_input = {bias};
const int num_args = 1; if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
fused_ops.end()) {
if (!NativeFormatEnabled()) { fused_input.push_back(input);
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
} else {
TF_EXPECT_OK(
NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklNameChangeOp")
.Finalize(node_def()));
}
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(input.shape(), input.flat<T>());
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
AddInputFromArray<T>(bias.shape(), bias.flat<T>());
if (!NativeFormatEnabled()) {
// Add MKL meta input for input, filter and bias.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_tensor = *GetOutput(0);
if (!NativeFormatEnabled()) {
const Tensor& output_meta_tensor = *GetOutput(1);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor,
output_meta_tensor, output);
} else {
*output = output_tensor;
} }
RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
}; };
CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch, CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
{"BiasAdd", "Tanh"}); {"BiasAdd", "Tanh"});
} }
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 4;
this->VerifyFusedMatMul(batch, input_channel, output_channel,
{"BiasAdd", "Add"});
}
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, // REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
WithBias, // WithBias, //
WithBiasAndRelu, // WithBiasAndRelu, //
WithBiasAndRelu6, // WithBiasAndRelu6, //
WithBiasAndElu, // WithBiasAndElu, //
WithBiasAndTanh); WithBiasAndTanh, //
WithBiasAndAdd);
using MklFusedMatMulDataTypes = ::testing::Types<float>; using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest, INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,

View File

@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
ctx, fused_ops_[0] == "BiasAdd", ctx, fused_ops_[0] == "BiasAdd",
errors::InvalidArgument( errors::InvalidArgument(
"The 1st post-argument of MklFusedMatMul must be BiasAdd.")); "The 1st post-argument of MklFusedMatMul must be BiasAdd."));
if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
OP_REQUIRES( OP_REQUIRES(
ctx, transpose_a_ == false, ctx, transpose_a_ == false,
errors::InvalidArgument("In[0] of MklMatMul can't be transposed.")); errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
// 2. var, keep the original format to avoid reordering. // 2. var, keep the original format to avoid reordering.
MklDnnMatMulFwdParams matmul_params( MklDnnMatMulFwdParams matmul_params(
src_dims, weight_dims, bias_dims, dst_dims, src_format, src_dims, weight_dims, bias_dims, dst_dims, src_format,
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format); (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
MEMORY_FORMAT::nc);
// Extend the basic parameters for data types and fusions. // Extend the basic parameters for data types and fusions.
ExtendMklDnnMatMulFwdParams(ctx, matmul_params); ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd = std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
matmul_prim->GetPrimitiveDesc(); matmul_prim->GetPrimitiveDesc();
if (src_mkl_shape.IsMklTensor()) { // The output shape of MatMul is same both for MKL and TF version.
this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, // They are all NC format, no matter what's the format of input.
MKL_TENSOR_FORMAT_NC, &dst_tensor); // And the shape of AddOp is also the same with output's shape.
auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
TensorShape output_tf_shape({batch, channel});
if (fuse_add_) {
const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
MklDnnShape add_mkl_shape;
GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
// For native format, we need not to set metadata.
if (native_format && ctx->forward_input_to_output_with_shape(
kInputIndex_Add, kOutputIndex_Dst,
output_tf_shape, &dst_tensor)) {
; // Need to do nothing for native format
} else if (!native_format && ForwardMklTensorInToOutWithMklShape(
ctx, kInputIndex_Add, kOutputIndex_Dst,
&dst_tensor, output_mkl_shape, false)) {
; // If it's not native format, need to forward and set meta first
} else {
// If forward is not successful, we should use reorder to copy add
// tensor to dst tensor
AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
output_tf_shape, output_mkl_shape,
native_format);
auto output_format_tag =
MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
auto add_md =
add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout()
: memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
auto dst_md =
memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
void* add_buf =
static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
if (native_format) {
// We are simply deep copying the add_tensor to dst_tensor without
// changing memory layout, hence using same memory descriptor.
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
mkldnn::memory::format_tag::x);
}
auto fuse_add_src_ =
MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
auto fuse_add_dst_ =
MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
this->cpu_engine_, ctx);
}
} else { } else {
TensorShape dst_tensor_shape({batch, channel}); AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
MklDnnShape dst_mkl_shape; output_mkl_shape, native_format);
dst_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
dst_mkl_shape, native_format);
} }
// if there's nothing to compute, just return. // if there's nothing to compute, just return.
@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}}); params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
} else if (post_op == "Tanh") { } else if (post_op == "Tanh") {
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}}); params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
} else if (post_op == "Add") {
params.post_op_params.push_back({"sum", {1.0}});
} else { } else {
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, errors::InvalidArgument( ctx, errors::InvalidArgument(
@ -237,10 +296,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
} }
private: private:
bool fuse_add_ = false;
bool transpose_a_; bool transpose_a_;
bool transpose_b_; bool transpose_b_;
std::vector<string> fused_ops_; std::vector<string> fused_ops_;
}; const int kInputIndex_Add = 3;
const int kOutputIndex_Dst = 0;
}; // namespace tensorflow
// Register mkl kernels for supported operations and types. // Register mkl kernels for supported operations and types.
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \ #define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \

View File

@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
memory::dims weight_dims; memory::dims weight_dims;
memory::dims bias_dims; memory::dims bias_dims;
memory::dims dst_dims; memory::dims dst_dims;
memory::format_tag src_format; MEMORY_FORMAT src_format;
memory::format_tag weight_format; MEMORY_FORMAT weight_format;
MEMORY_FORMAT dst_format;
string dtypes = string(""); string dtypes = string("");
struct PostOpParam { struct PostOpParam {
string name; string name;
@ -57,17 +58,18 @@ struct MklDnnMatMulFwdParams {
}; };
std::vector<PostOpParam> post_op_params; std::vector<PostOpParam> post_op_params;
MklDnnMatMulFwdParams( MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims, memory::dims bias_dims, memory::dims dst_dims,
memory::dims dst_dims, MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
memory::format_tag src_format = memory::format_tag::any, MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
memory::format_tag weight_format = memory::format_tag::any) MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
: src_dims(src_dims), : src_dims(src_dims),
weight_dims(weight_dims), weight_dims(weight_dims),
bias_dims(bias_dims), bias_dims(bias_dims),
dst_dims(dst_dims), dst_dims(dst_dims),
src_format(src_format), src_format(src_format),
weight_format(weight_format) {} weight_format(weight_format),
dst_format(dst_format) {}
}; };
// With quantization, input, weight, bias, and output can have different types. // With quantization, input, weight, bias, and output can have different types.
@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
MklDnnType<Toutput>(), MklDnnType<Toutput>(),
memory::format_tag::any)); matmul_fwd_params.dst_format));
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims}, context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(), MklDnnType<Tbias>(),
@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
std::vector<float> scales; std::vector<float> scales;
scales.push_back(post_op_param.param[0]); scales.push_back(post_op_param.param[0]);
post_ops_attr.set_output_scales(0, scales); post_ops_attr.set_output_scales(0, scales);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
float op_scale = post_op_param.param[0];
post_ops.append_sum(op_scale);
} else { } else {
DCHECK((post_op_param.name == "relu") || DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "relu6") || (post_op_param.name == "relu6") ||
(post_op_param.name == "elu") || (post_op_param.name == "elu") ||
(post_op_param.name == "tanh") || (post_op_param.name == "tanh") ||
(post_op_param.name == "sum") ||
(post_op_param.name == "output_scale")); (post_op_param.name == "output_scale"));
} }
} }
@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(post_op_param.param[0]); key_creator.AddAsKey(post_op_param.param[0]);
key_creator.AddAsKey(post_op_param.param[1]); key_creator.AddAsKey(post_op_param.param[1]);
key_creator.AddAsKey(post_op_param.param[2]); key_creator.AddAsKey(post_op_param.param[2]);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
} else if (post_op_param.name == "output_scale") { } else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1); DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.name);