Merge pull request #45358 from Intel-tensorflow:yanzhang/matmul_biasadd_add_fusion
PiperOrigin-RevId: 347408532
Change-Id: I6e8b12dfef056f095449fd70ce387b79d3e8b4d7
commit cd052fa5f0
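In graph terms, this change teaches the MKL remapper and the MKL fused-MatMul kernel to collapse a MatMul followed by BiasAdd and Add into one `_FusedMatMul` node. A minimal sketch of the rewrite, using the node names from the remapper test below:

    // Before remapping:
    //   matmul   = MatMul(input, filter)
    //   bias_add = BiasAdd(matmul, bias)
    //   add      = Add(bias_add, input_add)
    //
    // After remapping (a single node that reuses the Add node's name):
    //   add = _FusedMatMul(input, filter, bias, input_add)
    //   // attrs: num_args = 2, fused_ops = {"BiasAdd", "Add"}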
@@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
     }
   }
 }
+
+TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
+  using ::tensorflow::ops::Placeholder;
+
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({4, 32});
+  auto input_shape_add = ops::Placeholder::Shape({4, 8});
+  auto filter_shape = ops::Placeholder::Shape({32, 8});
+  auto bias_shape = ops::Placeholder::Shape({8});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_add =
+      Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
+  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
+  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+
+  auto fetch = s.WithOpName("fetch");
+  auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
+
+  ops::Identity(fetch, add);
+
+  auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape.shape_.dim_sizes()));
+  auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape_add.shape_.dim_sizes()));
+  auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(filter_shape.shape_.dim_sizes()));
+  auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(bias_shape.shape_.dim_sizes()));
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_tensor},
+               {"filter", filter_tensor},
+               {"bias", bias_tensor},
+               {"input_add", input_add_tensor}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on CPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:CPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    auto fetch_node_name = "add";
+    if (node.name() == fetch_node_name) {
+      EXPECT_EQ("_FusedMatMul", node.op());
+      EXPECT_EQ("input", node.input(0));
+      EXPECT_EQ("filter", node.input(1));
+
+      EXPECT_EQ(2, node.attr().at("num_args").i());
+      EXPECT_EQ("bias", node.input(2));
+      EXPECT_EQ("input_add", node.input(3));
+
+      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      EXPECT_EQ(2, fused_ops.size());
+      EXPECT_EQ("BiasAdd", fused_ops[0]);
+      EXPECT_EQ("Add", fused_ops[1]);
+      found++;
+    }
+  }
+  EXPECT_EQ(1, found);
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
+}
 #endif  // ENABLE_MKLDNN_V1

 }  // namespace grappler
@@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);

-  // MKL version only support fusion for Conv2D
-  DCHECK(IsConv2D(contraction));
+  // MKL version only support fusion for Conv2D and MatMul
+  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));

-  NodeDef fused_conv2d;
+  NodeDef contraction_node;
   const NodeDef& add = graph->node(matched.add);
-  fused_conv2d.set_name(add.name());
-  fused_conv2d.set_op(kFusedConv2D);
-  fused_conv2d.set_device(contraction.device());
-  fused_conv2d.add_input(contraction.input(0));  // 0: input
-  fused_conv2d.add_input(contraction.input(1));  // 1: filter
-  fused_conv2d.add_input(bias_add.input(1));     // 2: bias
+  contraction_node.set_name(add.name());
+  contraction_node.set_device(contraction.device());
+  contraction_node.add_input(
+      contraction.input(0));  // 0: input(conv) / a (matmul)
+  contraction_node.add_input(
+      contraction.input(1));  // 1: filter(conv) / b (matmul)
+  contraction_node.add_input(bias_add.input(1));  // 2: bias

-  // Add OP has two inputs, one is conv+bias pattern matched previously,
-  // the other input to add is fused here.
-  fused_conv2d.add_input(add.input(1 - matched.port_id));
+  // Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
+  // previously, the other input to add is fused here.
+  contraction_node.add_input(add.input(1 - matched.port_id));

-  CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
+  if (IsConv2D(contraction)) {
+    contraction_node.set_op(kFusedConv2D);
+    CopyConv2DAttributes(contraction, &contraction_node);
+  } else if (IsMatMul(contraction)) {
+    contraction_node.set_op(kFusedMatMul);
+    CopyMatMulAttributes(contraction, &contraction_node);
+  }
+
+  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);

   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
-  mutation->AddNode(std::move(fused_conv2d), &status);
+  mutation->AddNode(std::move(contraction_node), &status);
   TF_RETURN_IF_ERROR(status);
   TF_RETURN_IF_ERROR(mutation->Apply());

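Note the design choice above: the fused node takes over the Add node's name (`add.name()`), so downstream consumers of the original Add output stay valid after the rewrite, and only the op type (`kFusedConv2D` vs `kFusedMatMul`) and the attribute-copy helper differ between the two contraction kinds.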
@@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 }

 #ifdef INTEL_MKL
-bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
+bool IsConv2DOrMatMul(const NodeDef& node) {
+  return IsConv2D(node) || IsMatMul(node);
+}
+
+bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
   const auto* node_view = ctx.graph_view.GetNode(node_index);
   const auto* node_def = node_view->node();

   // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
+  // MatMul + Add or MatMul + BiasAdd + Add fusion.
   auto is_supported_add_input = [](const auto* node_view) -> bool {
-    if (IsConv2D(*node_view->node())) return true;
+    // Currently only support Conv2D and MatMul
+    if (IsConv2DOrMatMul(*node_view->node())) return true;
     if (IsBiasAdd(*node_view->node())) {
       if (node_view->NumRegularFanins() < 2) return false;
       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
-      return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
-             IsConv2D(*bias_add_fanin_1.node_view()->node());
+      return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
+             IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
     }
     return false;
   };
@@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {

 #ifdef INTEL_MKL
   return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
-         IsConv2DWithAdd(ctx, node_index);
+         IsContractionWithAdd(ctx, node_index);
 #else
   return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
          is_batch_norm_fusion_candidate();
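For reference, the `is_supported_add_input` predicate above accepts an Add whose fanin is either a bare contraction or a BiasAdd fed by one, so the graph shapes the remapper now considers look like this (illustrative sketch, not code from the change):

    //   Add(Conv2D(x, w), y)      Add(BiasAdd(Conv2D(x, w), b), y)
    //   Add(MatMul(x, w), y)      Add(BiasAdd(MatMul(x, w), b), y)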
@@ -955,6 +955,71 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
 // Testing fusion of MatMul and BiasAdd
 template <typename T>
 class MklFusedMatMulOpTest : public OpsTestBase {
+ private:
+  void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
+                           const std::vector<Tensor>& args,
+                           const std::vector<string>& fused_ops,
+                           Tensor* output) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    const int num_args = args.size();
+    if (!NativeFormatEnabled()) {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(num_args, DT_UINT8))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklLayoutDependentOp")
+                       .Finalize(node_def()));
+    } else {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklNameChangeOp")
+                       .Finalize(node_def()));
+    }
+
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<T>(input.shape(), input.flat<T>());
+    AddInputFromArray<T>(weight.shape(), weight.flat<T>());
+    for (const Tensor& arg : args)
+      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
+    if (!NativeFormatEnabled()) {
+      // Add MKL meta input for input, filter and bias.
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      for (const Tensor& arg : args)
+        AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    }
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& output_tensor = *GetOutput(0);
+    if (!NativeFormatEnabled()) {
+      const Tensor& output_meta_tensor = *GetOutput(1);
+      CommonTestUtilities<T> test_util;
+      test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
+                                  output);
+    } else {
+      *output = output_tensor;
+    }
+  }
+
  protected:
   void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
                          const int kOutputChannel,
@@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase {
         next_op = ops::Tanh(root.WithOpName(last_op), next_op);
       }

+      if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+          fused_ops.end()) {
+        last_op = "with_add";
+        next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
+      }
+
       CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
     };

     const FusedGraphRunner run_fused =
         [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
                const std::vector<string>& fused_ops, Tensor* output) {
-          DataType dtype = DataTypeToEnum<T>::v();
-          const int num_args = 1;
-
-          if (!NativeFormatEnabled()) {
-            TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(num_args, dtype))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(num_args, DT_UINT8))
-                             .Attr("T", dtype)
-                             .Attr("transpose_a", false)
-                             .Attr("transpose_b", false)
-                             .Attr("num_args", num_args)
-                             .Attr("fused_ops", fused_ops)
-                             .Attr("epsilon", 0.0001)
-                             .Attr("_kernel", "MklLayoutDependentOp")
-                             .Finalize(node_def()));
-          } else {
-            TF_EXPECT_OK(
-                NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(num_args, dtype))
-                    .Attr("T", dtype)
-                    .Attr("transpose_a", false)
-                    .Attr("transpose_b", false)
-                    .Attr("num_args", num_args)
-                    .Attr("fused_ops", fused_ops)
-                    .Attr("epsilon", 0.0001)
-                    .Attr("_kernel", "MklNameChangeOp")
-                    .Finalize(node_def()));
-          }
-
-          TF_EXPECT_OK(InitOp());
-
-          AddInputFromArray<T>(input.shape(), input.flat<T>());
-          AddInputFromArray<T>(weight.shape(), weight.flat<T>());
-          AddInputFromArray<T>(bias.shape(), bias.flat<T>());
-          if (!NativeFormatEnabled()) {
-            // Add MKL meta input for input, filter and bias.
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-          }
-
-          TF_ASSERT_OK(RunOpKernel());
-
-          const Tensor& output_tensor = *GetOutput(0);
-          if (!NativeFormatEnabled()) {
-            const Tensor& output_meta_tensor = *GetOutput(1);
-            CommonTestUtilities<T> test_util;
-            test_util.PerformConversion(dtype, output_tensor,
-                                        output_meta_tensor, output);
-          } else {
-            *output = output_tensor;
-          }
+          std::vector<Tensor> fused_input = {bias};
+          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+              fused_ops.end()) {
+            fused_input.push_back(input);
+          }
+          RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
         };

     CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
@@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
                          {"BiasAdd", "Tanh"});
 }

+TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
+  const int batch = 3;
+  const int input_channel = 4;
+  const int output_channel = 4;
+
+  this->VerifyFusedMatMul(batch, input_channel, output_channel,
+                          {"BiasAdd", "Add"});
+}
+
 REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                             WithBias,              //
                             WithBiasAndRelu,       //
                             WithBiasAndRelu6,      //
                             WithBiasAndElu,        //
-                            WithBiasAndTanh);
+                            WithBiasAndTanh,       //
+                            WithBiasAndAdd);

 using MklFusedMatMulDataTypes = ::testing::Types<float>;
 INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
@@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
         ctx, fused_ops_[0] == "BiasAdd",
         errors::InvalidArgument(
             "The 1st post-argument of MklFusedMatMul must be BiasAdd."));
+    if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
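The order of checks matters here: the remapper always emits `fused_ops = {"BiasAdd", "Add"}` (see `SetFusedOpAttributes` earlier in this commit), so after verifying that "BiasAdd" occupies slot 0 the kernel only needs to test slot 1 for "Add".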
@@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     // 2. var, keep the original format to avoid reordering.
     MklDnnMatMulFwdParams matmul_params(
         src_dims, weight_dims, bias_dims, dst_dims, src_format,
-        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
+        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
+        MEMORY_FORMAT::nc);

     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
         matmul_prim->GetPrimitiveDesc();

-    if (src_mkl_shape.IsMklTensor()) {
-      this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
-                                 MKL_TENSOR_FORMAT_NC, &dst_tensor);
+    // The output shape of MatMul is same both for MKL and TF version.
+    // They are all NC format, no matter what's the format of input.
+    // And the shape of AddOp is also the same with output's shape.
+    auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
+
+    MklDnnShape output_mkl_shape;
+    output_mkl_shape.SetMklTensor(false);
+
+    TensorShape output_tf_shape({batch, channel});
+
+    if (fuse_add_) {
+      const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
+      MklDnnShape add_mkl_shape;
+      GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
+
+      // For native format, we need not to set metadata.
+      if (native_format && ctx->forward_input_to_output_with_shape(
+                               kInputIndex_Add, kOutputIndex_Dst,
+                               output_tf_shape, &dst_tensor)) {
+        ;  // Need to do nothing for native format
+      } else if (!native_format && ForwardMklTensorInToOutWithMklShape(
+                                       ctx, kInputIndex_Add, kOutputIndex_Dst,
+                                       &dst_tensor, output_mkl_shape, false)) {
+        ;  // If it's not native format, need to forward and set meta first
+      } else {
+        // If forward is not successful, we should use reorder to copy add
+        // tensor to dst tensor
+        AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
+                                  output_tf_shape, output_mkl_shape,
+                                  native_format);
+        auto output_format_tag =
+            MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+        auto add_md =
+            add_mkl_shape.IsMklTensor()
+                ? add_mkl_shape.GetMklLayout()
+                : memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+        auto dst_md =
+            memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+
+        void* add_buf =
+            static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
+        void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
+
+        if (native_format) {
+          // We are simply deep copying the add_tensor to dst_tensor without
+          // changing memory layout, hence using same memory descriptor.
+          add_md = dst_md =
+              memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
+                           mkldnn::memory::format_tag::x);
+        }
+
+        auto fuse_add_src_ =
+            MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
+        auto fuse_add_dst_ =
+            MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+        auto reorder_desc =
+            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+
+        CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
+                                this->cpu_engine_, ctx);
+      }
     } else {
-      TensorShape dst_tensor_shape({batch, channel});
-      MklDnnShape dst_mkl_shape;
-      dst_mkl_shape.SetMklTensor(false);
-      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
-                                dst_mkl_shape, native_format);
+      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
+                                output_mkl_shape, native_format);
     }

     // if there's nothing to compute, just return.
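A note on the three paths in the `fuse_add_` branch above: the oneDNN `sum` post-op used for the Add fusion accumulates into the destination buffer in place, so the Add operand must already reside in `dst_tensor` before the primitive executes. The code therefore first tries to forward the Add input buffer straight to the output (zero copy, with native-format and MKL-layout variants), and only falls back to a reorder-based copy when forwarding fails.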
@@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
         params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
       } else if (post_op == "Tanh") {
         params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
+      } else if (post_op == "Add") {
+        params.post_op_params.push_back({"sum", {1.0}});
       } else {
         OP_REQUIRES_OK(
             ctx, errors::InvalidArgument(
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
bool fuse_add_ = false;
|
||||||
bool transpose_a_;
|
bool transpose_a_;
|
||||||
bool transpose_b_;
|
bool transpose_b_;
|
||||||
std::vector<string> fused_ops_;
|
std::vector<string> fused_ops_;
|
||||||
};
|
const int kInputIndex_Add = 3;
|
||||||
|
const int kOutputIndex_Dst = 0;
|
||||||
|
}; // namespace tensorflow
|
||||||
|
|
||||||
// Register mkl kernels for supported operations and types.
|
// Register mkl kernels for supported operations and types.
|
||||||
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
|
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
|
||||||
|
@@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
   memory::dims weight_dims;
   memory::dims bias_dims;
   memory::dims dst_dims;
-  memory::format_tag src_format;
-  memory::format_tag weight_format;
+  MEMORY_FORMAT src_format;
+  MEMORY_FORMAT weight_format;
+  MEMORY_FORMAT dst_format;
   string dtypes = string("");
   struct PostOpParam {
     string name;
|
|||||||
};
|
};
|
||||||
std::vector<PostOpParam> post_op_params;
|
std::vector<PostOpParam> post_op_params;
|
||||||
|
|
||||||
MklDnnMatMulFwdParams(
|
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
|
||||||
memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
|
memory::dims bias_dims, memory::dims dst_dims,
|
||||||
memory::dims dst_dims,
|
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
|
||||||
memory::format_tag src_format = memory::format_tag::any,
|
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
|
||||||
memory::format_tag weight_format = memory::format_tag::any)
|
MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
|
||||||
: src_dims(src_dims),
|
: src_dims(src_dims),
|
||||||
weight_dims(weight_dims),
|
weight_dims(weight_dims),
|
||||||
bias_dims(bias_dims),
|
bias_dims(bias_dims),
|
||||||
dst_dims(dst_dims),
|
dst_dims(dst_dims),
|
||||||
src_format(src_format),
|
src_format(src_format),
|
||||||
weight_format(weight_format) {}
|
weight_format(weight_format),
|
||||||
|
dst_format(dst_format) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// With quantization, input, weight, bias, and output can have different types.
|
// With quantization, input, weight, bias, and output can have different types.
|
||||||
@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
|
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
|
||||||
MklDnnType<Toutput>(),
|
MklDnnType<Toutput>(),
|
||||||
memory::format_tag::any));
|
matmul_fwd_params.dst_format));
|
||||||
|
|
||||||
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
|
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
|
||||||
MklDnnType<Tbias>(),
|
MklDnnType<Tbias>(),
|
||||||
@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
|||||||
std::vector<float> scales;
|
std::vector<float> scales;
|
||||||
scales.push_back(post_op_param.param[0]);
|
scales.push_back(post_op_param.param[0]);
|
||||||
post_ops_attr.set_output_scales(0, scales);
|
post_ops_attr.set_output_scales(0, scales);
|
||||||
|
} else if (post_op_param.name == "sum") {
|
||||||
|
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||||
|
float op_scale = post_op_param.param[0];
|
||||||
|
post_ops.append_sum(op_scale);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
DCHECK((post_op_param.name == "relu") ||
|
DCHECK((post_op_param.name == "relu") ||
|
||||||
(post_op_param.name == "relu6") ||
|
(post_op_param.name == "relu6") ||
|
||||||
(post_op_param.name == "elu") ||
|
(post_op_param.name == "elu") ||
|
||||||
(post_op_param.name == "tanh") ||
|
(post_op_param.name == "tanh") ||
|
||||||
|
(post_op_param.name == "sum") ||
|
||||||
(post_op_param.name == "output_scale"));
|
(post_op_param.name == "output_scale"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
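For context, here is a self-contained sketch of what the "sum" post-op configured above means at the oneDNN level. It assumes the v1.x `mkldnn.hpp` API, and `MakeSumAttr` is an illustrative helper, not code from this change:

    #include "mkldnn.hpp"

    // Builds a primitive_attr whose post-op makes the primitive accumulate
    // into its destination buffer:
    //   dst = op_scale * dst + (src * weights + bias)
    mkldnn::primitive_attr MakeSumAttr(float op_scale) {
      mkldnn::post_ops po;
      po.append_sum(op_scale);  // read-modify-write on dst
      mkldnn::primitive_attr attr;
      attr.set_post_ops(po);
      return attr;
    }

This in-place accumulation is the reason the kernel's Compute path copies or forwards the Add operand into the output tensor before running the inner-product primitive.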
@@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
       key_creator.AddAsKey(post_op_param.param[0]);
       key_creator.AddAsKey(post_op_param.param[1]);
       key_creator.AddAsKey(post_op_param.param[2]);
+    } else if (post_op_param.name == "sum") {
+      DCHECK_EQ(post_op_param.param.size(), 1);
+      key_creator.AddAsKey(post_op_param.name);
+      key_creator.AddAsKey(post_op_param.param[0]);
     } else if (post_op_param.name == "output_scale") {
       DCHECK_EQ(post_op_param.param.size(), 1);
       key_creator.AddAsKey(post_op_param.name);
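Adding the "sum" post-op (name and scale) to the key above matters because forward primitives are cached and looked up by this key: without it, a cached plain-MatMul primitive could be returned for a fused-Add call, or vice versa, even though their descriptors differ only in post-ops.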