Merge pull request #45358 from Intel-tensorflow:yanzhang/matmul_biasadd_add_fusion

PiperOrigin-RevId: 347408532
Change-Id: I6e8b12dfef056f095449fd70ce387b79d3e8b4d7
TensorFlower Gardener, 2020-12-14 10:05:53 -08:00
commit cd052fa5f0
5 changed files with 292 additions and 97 deletions


@@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
}
}
}
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({4, 32});
auto input_shape_add = ops::Placeholder::Shape({4, 8});
auto filter_shape = ops::Placeholder::Shape({32, 8});
auto bias_shape = ops::Placeholder::Shape({8});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_add =
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
auto fetch = s.WithOpName("fetch");
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
ops::Identity(fetch, add);
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape.shape_.dim_sizes()));
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape_add.shape_.dim_sizes()));
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape.shape_.dim_sizes()));
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(bias_shape.shape_.dim_sizes()));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_tensor},
{"filter", filter_tensor},
{"bias", bias_tensor},
{"input_add", input_add_tensor}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
auto fetch_node_name = "add";
if (node.name() == fetch_node_name) {
EXPECT_EQ("_FusedMatMul", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
EXPECT_EQ(2, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
EXPECT_EQ("input_add", node.input(3));
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
found++;
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
}
#endif // ENABLE_MKLDNN_V1
} // namespace grappler
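Beyond checking the rewritten graph structure, the test evaluates both graphs and compares their outputs. As a minimal standalone sketch (plain C++, not part of this change; all names are illustrative), the fused node is expected to reproduce out = input * filter + bias + input_add:

#include <vector>

// Reference computation for MatMul + BiasAdd + Add with input [batch, k],
// filter [k, channels], bias [channels] and input_add [batch, channels].
std::vector<float> FusedMatMulBiasAddAddReference(
    const std::vector<float>& input, const std::vector<float>& filter,
    const std::vector<float>& bias, const std::vector<float>& input_add,
    int batch, int k, int channels) {
  std::vector<float> out(batch * channels, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      float acc = bias[c];  // BiasAdd contribution
      for (int i = 0; i < k; ++i) {
        acc += input[b * k + i] * filter[i * channels + c];  // MatMul
      }
      out[b * channels + c] = acc + input_add[b * channels + c];  // Add
    }
  }
  return out;
}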


@@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
const NodeDef& contraction = graph->node(matched.contraction);
const NodeDef& bias_add = graph->node(matched.bias_add);
// The MKL version only supports fusion for Conv2D.
DCHECK(IsConv2D(contraction));
// The MKL version only supports fusion for Conv2D and MatMul.
DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
NodeDef fused_conv2d;
NodeDef contraction_node;
const NodeDef& add = graph->node(matched.add);
fused_conv2d.set_name(add.name());
fused_conv2d.set_op(kFusedConv2D);
fused_conv2d.set_device(contraction.device());
fused_conv2d.add_input(contraction.input(0)); // 0: input
fused_conv2d.add_input(contraction.input(1)); // 1: filter
fused_conv2d.add_input(bias_add.input(1)); // 2: bias
contraction_node.set_name(add.name());
contraction_node.set_device(contraction.device());
contraction_node.add_input(
contraction.input(0)); // 0: input(conv) / a (matmul)
contraction_node.add_input(
contraction.input(1)); // 1: filter(conv) / b (matmul)
contraction_node.add_input(bias_add.input(1)); // 2: bias
// The Add op has two inputs: one is the conv+bias pattern matched previously;
// the other input to Add is fused here.
fused_conv2d.add_input(add.input(1 - matched.port_id));
// The Add op has two inputs: one is the conv+bias/matmul+bias pattern matched
// previously; the other input to Add is fused here.
contraction_node.add_input(add.input(1 - matched.port_id));
CopyConv2DAttributes(contraction, &fused_conv2d);
SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
if (IsConv2D(contraction)) {
contraction_node.set_op(kFusedConv2D);
CopyConv2DAttributes(contraction, &contraction_node);
} else if (IsMatMul(contraction)) {
contraction_node.set_op(kFusedMatMul);
CopyMatMulAttributes(contraction, &contraction_node);
}
SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
mutation->AddNode(std::move(fused_conv2d), &status);
mutation->AddNode(std::move(contraction_node), &status);
TF_RETURN_IF_ERROR(status);
TF_RETURN_IF_ERROR(mutation->Apply());
@@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
}
#ifdef INTEL_MKL
bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
bool IsConv2DOrMatMul(const NodeDef& node) {
return IsConv2D(node) || IsMatMul(node);
}
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
const auto* node_def = node_view->node();
// Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion, or for
// MatMul + Add or MatMul + BiasAdd + Add fusion.
auto is_supported_add_input = [](const auto* node_view) -> bool {
if (IsConv2D(*node_view->node())) return true;
// Currently only Conv2D and MatMul are supported.
if (IsConv2DOrMatMul(*node_view->node())) return true;
if (IsBiasAdd(*node_view->node())) {
if (node_view->NumRegularFanins() < 2) return false;
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
IsConv2D(*bias_add_fanin_1.node_view()->node());
return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
}
return false;
};
@@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
#ifdef INTEL_MKL
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
IsConv2DWithAdd(ctx, node_index);
IsContractionWithAdd(ctx, node_index);
#else
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
is_batch_norm_fusion_candidate();
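The fanin check above can be hard to read against the Grappler graph-view API. A simplified standalone sketch of the same pattern test, using a toy node type instead of Grappler's node views (all types and names below are hypothetical):

#include <string>
#include <vector>

struct ToyNode {
  std::string op;
  std::vector<const ToyNode*> fanins;  // regular inputs
};

bool IsConv2DOrMatMulToy(const ToyNode& n) {
  return n.op == "Conv2D" || n.op == "MatMul";
}

// An Add input supports fusion if it is itself a contraction, or a BiasAdd
// fed by a contraction on either of its two inputs.
bool IsSupportedAddInputToy(const ToyNode& n) {
  if (IsConv2DOrMatMulToy(n)) return true;
  if (n.op == "BiasAdd" && n.fanins.size() >= 2) {
    return IsConv2DOrMatMulToy(*n.fanins[0]) || IsConv2DOrMatMulToy(*n.fanins[1]);
  }
  return false;
}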


@@ -955,6 +955,71 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
// Testing fusion of MatMul and BiasAdd
template <typename T>
class MklFusedMatMulOpTest : public OpsTestBase {
private:
void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
const std::vector<Tensor>& args,
const std::vector<string>& fused_ops,
Tensor* output) {
DataType dtype = DataTypeToEnum<T>::v();
const int num_args = args.size();
if (!NativeFormatEnabled()) {
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
} else {
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklNameChangeOp")
.Finalize(node_def()));
}
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(input.shape(), input.flat<T>());
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
for (const Tensor& arg : args)
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
if (!NativeFormatEnabled()) {
// Add MKL meta inputs for input, weight and the fused args.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_tensor = *GetOutput(0);
if (!NativeFormatEnabled()) {
const Tensor& output_meta_tensor = *GetOutput(1);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
output);
} else {
*output = output_tensor;
}
}
protected:
void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
const int kOutputChannel,
@@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase {
next_op = ops::Tanh(root.WithOpName(last_op), next_op);
}
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
fused_ops.end()) {
last_op = "with_add";
next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
};
const FusedGraphRunner run_fused =
[this](const Tensor& input, const Tensor& weight, const Tensor& bias,
const std::vector<string>& fused_ops, Tensor* output) {
DataType dtype = DataTypeToEnum<T>::v();
const int num_args = 1;
if (!NativeFormatEnabled()) {
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklLayoutDependentOp")
.Finalize(node_def()));
} else {
TF_EXPECT_OK(
NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Input(FakeInput(num_args, dtype))
.Attr("T", dtype)
.Attr("transpose_a", false)
.Attr("transpose_b", false)
.Attr("num_args", num_args)
.Attr("fused_ops", fused_ops)
.Attr("epsilon", 0.0001)
.Attr("_kernel", "MklNameChangeOp")
.Finalize(node_def()));
}
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(input.shape(), input.flat<T>());
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
AddInputFromArray<T>(bias.shape(), bias.flat<T>());
if (!NativeFormatEnabled()) {
// Add MKL meta input for input, filter and bias.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_tensor = *GetOutput(0);
if (!NativeFormatEnabled()) {
const Tensor& output_meta_tensor = *GetOutput(1);
CommonTestUtilities<T> test_util;
test_util.PerformConversion(dtype, output_tensor,
output_meta_tensor, output);
} else {
*output = output_tensor;
std::vector<Tensor> fused_input = {bias};
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
fused_ops.end()) {
fused_input.push_back(input);
}
RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
};
CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
@@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
{"BiasAdd", "Tanh"});
}
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 4;
this->VerifyFusedMatMul(batch, input_channel, output_channel,
{"BiasAdd", "Add"});
}
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
WithBias, //
WithBiasAndRelu, //
WithBiasAndRelu6, //
WithBiasAndElu, //
WithBiasAndTanh);
WithBiasAndTanh, //
WithBiasAndAdd);
using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,


@@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
ctx, fused_ops_[0] == "BiasAdd",
errors::InvalidArgument(
"The 1st post-argument of MklFusedMatMul must be BiasAdd."));
if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
OP_REQUIRES(
ctx, transpose_a_ == false,
errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
@@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
// 2. var, keep the original format to avoid reordering.
MklDnnMatMulFwdParams matmul_params(
src_dims, weight_dims, bias_dims, dst_dims, src_format,
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
MEMORY_FORMAT::nc);
// Extend the basic parameters for data types and fusions.
ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
matmul_prim->GetPrimitiveDesc();
if (src_mkl_shape.IsMklTensor()) {
this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
MKL_TENSOR_FORMAT_NC, &dst_tensor);
// The output shape of MatMul is the same for both the MKL and TF versions:
// it is always in NC format, regardless of the input format.
// The shape of the Add op's input also matches the output shape.
auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
TensorShape output_tf_shape({batch, channel});
if (fuse_add_) {
const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
MklDnnShape add_mkl_shape;
GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
// For native format, there is no need to set metadata.
if (native_format && ctx->forward_input_to_output_with_shape(
kInputIndex_Add, kOutputIndex_Dst,
output_tf_shape, &dst_tensor)) {
; // Nothing more to do for the native format path.
} else if (!native_format && ForwardMklTensorInToOutWithMklShape(
ctx, kInputIndex_Add, kOutputIndex_Dst,
&dst_tensor, output_mkl_shape, false)) {
; // For non-native format, forward the tensor and set its metadata first.
} else {
// If forwarding is not successful, use a reorder to copy the add
// tensor into the dst tensor.
AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
output_tf_shape, output_mkl_shape,
native_format);
auto output_format_tag =
MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
auto add_md =
add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout()
: memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
auto dst_md =
memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
void* add_buf =
static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
if (native_format) {
// We simply deep-copy add_tensor to dst_tensor without changing the
// memory layout, hence the same memory descriptor is used for both.
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
mkldnn::memory::format_tag::x);
}
auto fuse_add_src_ =
MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
auto fuse_add_dst_ =
MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
this->cpu_engine_, ctx);
}
} else {
TensorShape dst_tensor_shape({batch, channel});
MklDnnShape dst_mkl_shape;
dst_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
dst_mkl_shape, native_format);
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
output_mkl_shape, native_format);
}
// If there's nothing to compute, just return.
@@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
} else if (post_op == "Tanh") {
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
} else if (post_op == "Add") {
params.post_op_params.push_back({"sum", {1.0}});
} else {
OP_REQUIRES_OK(
ctx, errors::InvalidArgument(
@@ -237,10 +296,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
}
private:
bool fuse_add_ = false;
bool transpose_a_;
bool transpose_b_;
std::vector<string> fused_ops_;
};
const int kInputIndex_Add = 3;
const int kOutputIndex_Dst = 0;
}; // namespace tensorflow
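The forwarding/reorder logic above exists because the "Add" fusion is implemented as a oneDNN sum post-op: the primitive accumulates its result on top of whatever is already in the destination buffer, so the add tensor must live in dst before the primitive runs. A minimal sketch of that semantic (plain C++, not the oneDNN API; names are illustrative):

#include <vector>

// dst is pre-filled with the residual (add) tensor; the primitive's
// MatMul+Bias result is then accumulated onto it, scaled by sum_scale.
void ApplySumPostOp(const std::vector<float>& matmul_bias_result,
                    std::vector<float>* dst, float sum_scale = 1.0f) {
  for (size_t i = 0; i < dst->size(); ++i) {
    (*dst)[i] = matmul_bias_result[i] + sum_scale * (*dst)[i];
  }
}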
// Register mkl kernels for supported operations and types.
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \


@@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
memory::dims weight_dims;
memory::dims bias_dims;
memory::dims dst_dims;
memory::format_tag src_format;
memory::format_tag weight_format;
MEMORY_FORMAT src_format;
MEMORY_FORMAT weight_format;
MEMORY_FORMAT dst_format;
string dtypes = string("");
struct PostOpParam {
string name;
@@ -57,17 +58,18 @@ };
};
std::vector<PostOpParam> post_op_params;
MklDnnMatMulFwdParams(
memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
memory::dims dst_dims,
memory::format_tag src_format = memory::format_tag::any,
memory::format_tag weight_format = memory::format_tag::any)
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
memory::dims bias_dims, memory::dims dst_dims,
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
: src_dims(src_dims),
weight_dims(weight_dims),
bias_dims(bias_dims),
dst_dims(dst_dims),
src_format(src_format),
weight_format(weight_format) {}
weight_format(weight_format),
dst_format(dst_format) {}
};
// With quantization, input, weight, bias, and output can have different types.
@@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
MklDnnType<Toutput>(),
memory::format_tag::any));
matmul_fwd_params.dst_format));
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
MklDnnType<Tbias>(),
@@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
std::vector<float> scales;
scales.push_back(post_op_param.param[0]);
post_ops_attr.set_output_scales(0, scales);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
float op_scale = post_op_param.param[0];
post_ops.append_sum(op_scale);
} else {
DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "relu6") ||
(post_op_param.name == "elu") ||
(post_op_param.name == "tanh") ||
(post_op_param.name == "sum") ||
(post_op_param.name == "output_scale"));
}
}
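For reference, the "sum" branch above drives primitive-attribute plumbing that, in the MKL-DNN v1.x API used in this file, looks roughly like the following sketch (the helper name is hypothetical):

#include "mkldnn.hpp"

// Build a primitive_attr carrying a sum post-op, mirroring the
// post_ops.append_sum(op_scale) call above.
mkldnn::primitive_attr MakeSumPostOpAttr(float op_scale) {
  mkldnn::post_ops ops;
  ops.append_sum(op_scale);  // dst := primitive result + op_scale * dst
  mkldnn::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;
}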
@@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(post_op_param.param[0]);
key_creator.AddAsKey(post_op_param.param[1]);
key_creator.AddAsKey(post_op_param.param[2]);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
key_creator.AddAsKey(post_op_param.name);