Merge pull request #45358 from Intel-tensorflow:yanzhang/matmul_biasadd_add_fusion
PiperOrigin-RevId: 347408532 Change-Id: I6e8b12dfef056f095449fd70ce387b79d3e8b4d7
This commit is contained in:
commit
cd052fa5f0
@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({4, 32});
|
||||
auto input_shape_add = ops::Placeholder::Shape({4, 8});
|
||||
auto filter_shape = ops::Placeholder::Shape({32, 8});
|
||||
auto bias_shape = ops::Placeholder::Shape({8});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_add =
|
||||
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||
|
||||
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
|
||||
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
|
||||
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
|
||||
|
||||
ops::Identity(fetch, add);
|
||||
|
||||
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape.shape_.dim_sizes()));
|
||||
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape_add.shape_.dim_sizes()));
|
||||
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(filter_shape.shape_.dim_sizes()));
|
||||
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(bias_shape.shape_.dim_sizes()));
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_tensor},
|
||||
{"filter", filter_tensor},
|
||||
{"bias", bias_tensor},
|
||||
{"input_add", input_add_tensor}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on CPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE);
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
auto fetch_node_name = "add";
|
||||
if (node.name() == fetch_node_name) {
|
||||
EXPECT_EQ("_FusedMatMul", node.op());
|
||||
EXPECT_EQ("input", node.input(0));
|
||||
EXPECT_EQ("filter", node.input(1));
|
||||
|
||||
EXPECT_EQ(2, node.attr().at("num_args").i());
|
||||
EXPECT_EQ("bias", node.input(2));
|
||||
EXPECT_EQ("input_add", node.input(3));
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
EXPECT_EQ(2, fused_ops.size());
|
||||
EXPECT_EQ("BiasAdd", fused_ops[0]);
|
||||
EXPECT_EQ("Add", fused_ops[1]);
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(1, found);
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
} // namespace grappler
|
||||
|
@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
||||
const NodeDef& contraction = graph->node(matched.contraction);
|
||||
const NodeDef& bias_add = graph->node(matched.bias_add);
|
||||
|
||||
// MKL version only support fusion for Conv2D
|
||||
DCHECK(IsConv2D(contraction));
|
||||
// MKL version only support fusion for Conv2D and MatMul
|
||||
DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
|
||||
|
||||
NodeDef fused_conv2d;
|
||||
NodeDef contraction_node;
|
||||
const NodeDef& add = graph->node(matched.add);
|
||||
fused_conv2d.set_name(add.name());
|
||||
fused_conv2d.set_op(kFusedConv2D);
|
||||
fused_conv2d.set_device(contraction.device());
|
||||
fused_conv2d.add_input(contraction.input(0)); // 0: input
|
||||
fused_conv2d.add_input(contraction.input(1)); // 1: filter
|
||||
fused_conv2d.add_input(bias_add.input(1)); // 2: bias
|
||||
contraction_node.set_name(add.name());
|
||||
contraction_node.set_device(contraction.device());
|
||||
contraction_node.add_input(
|
||||
contraction.input(0)); // 0: input(conv) / a (matmul)
|
||||
contraction_node.add_input(
|
||||
contraction.input(1)); // 1: filter(conv) / b (matmul)
|
||||
contraction_node.add_input(bias_add.input(1)); // 2: bias
|
||||
|
||||
// Add OP has two inputs, one is conv+bias pattern matched previously,
|
||||
// the other input to add is fused here.
|
||||
fused_conv2d.add_input(add.input(1 - matched.port_id));
|
||||
// Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
|
||||
// previously, the other input to add is fused here.
|
||||
contraction_node.add_input(add.input(1 - matched.port_id));
|
||||
|
||||
CopyConv2DAttributes(contraction, &fused_conv2d);
|
||||
SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
|
||||
if (IsConv2D(contraction)) {
|
||||
contraction_node.set_op(kFusedConv2D);
|
||||
CopyConv2DAttributes(contraction, &contraction_node);
|
||||
} else if (IsMatMul(contraction)) {
|
||||
contraction_node.set_op(kFusedMatMul);
|
||||
CopyMatMulAttributes(contraction, &contraction_node);
|
||||
}
|
||||
|
||||
SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
|
||||
|
||||
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
|
||||
Status status;
|
||||
mutation->AddNode(std::move(fused_conv2d), &status);
|
||||
mutation->AddNode(std::move(contraction_node), &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
TF_RETURN_IF_ERROR(mutation->Apply());
|
||||
|
||||
@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
}
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
|
||||
bool IsConv2DOrMatMul(const NodeDef& node) {
|
||||
return IsConv2D(node) || IsMatMul(node);
|
||||
}
|
||||
|
||||
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
|
||||
const auto* node_view = ctx.graph_view.GetNode(node_index);
|
||||
const auto* node_def = node_view->node();
|
||||
|
||||
// Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
|
||||
// MatMul + Add or MatMul + BiasAdd + Add fusion.
|
||||
auto is_supported_add_input = [](const auto* node_view) -> bool {
|
||||
if (IsConv2D(*node_view->node())) return true;
|
||||
// Currently only support Conv2D and MatMul
|
||||
if (IsConv2DOrMatMul(*node_view->node())) return true;
|
||||
if (IsBiasAdd(*node_view->node())) {
|
||||
if (node_view->NumRegularFanins() < 2) return false;
|
||||
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
|
||||
const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
|
||||
return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
|
||||
IsConv2D(*bias_add_fanin_1.node_view()->node());
|
||||
return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
|
||||
IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
|
||||
IsConv2DWithAdd(ctx, node_index);
|
||||
IsContractionWithAdd(ctx, node_index);
|
||||
#else
|
||||
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
|
||||
is_batch_norm_fusion_candidate();
|
||||
|
@ -955,6 +955,71 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
|
||||
// Testing fusion of MatMul and BiasAdd
|
||||
template <typename T>
|
||||
class MklFusedMatMulOpTest : public OpsTestBase {
|
||||
private:
|
||||
void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
|
||||
const std::vector<Tensor>& args,
|
||||
const std::vector<string>& fused_ops,
|
||||
Tensor* output) {
|
||||
DataType dtype = DataTypeToEnum<T>::v();
|
||||
const int num_args = args.size();
|
||||
if (!NativeFormatEnabled()) {
|
||||
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(num_args, dtype))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(num_args, DT_UINT8))
|
||||
.Attr("T", dtype)
|
||||
.Attr("transpose_a", false)
|
||||
.Attr("transpose_b", false)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Attr("epsilon", 0.0001)
|
||||
.Attr("_kernel", "MklLayoutDependentOp")
|
||||
.Finalize(node_def()));
|
||||
} else {
|
||||
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(num_args, dtype))
|
||||
.Attr("T", dtype)
|
||||
.Attr("transpose_a", false)
|
||||
.Attr("transpose_b", false)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Attr("epsilon", 0.0001)
|
||||
.Attr("_kernel", "MklNameChangeOp")
|
||||
.Finalize(node_def()));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(InitOp());
|
||||
|
||||
AddInputFromArray<T>(input.shape(), input.flat<T>());
|
||||
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
|
||||
for (const Tensor& arg : args)
|
||||
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
|
||||
if (!NativeFormatEnabled()) {
|
||||
// Add MKL meta input for input, filter and bias.
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
for (const Tensor& arg : args)
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
const Tensor& output_tensor = *GetOutput(0);
|
||||
if (!NativeFormatEnabled()) {
|
||||
const Tensor& output_meta_tensor = *GetOutput(1);
|
||||
CommonTestUtilities<T> test_util;
|
||||
test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
|
||||
output);
|
||||
} else {
|
||||
*output = output_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
|
||||
const int kOutputChannel,
|
||||
@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase {
|
||||
next_op = ops::Tanh(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_add";
|
||||
next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
|
||||
}
|
||||
|
||||
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
|
||||
};
|
||||
|
||||
const FusedGraphRunner run_fused =
|
||||
[this](const Tensor& input, const Tensor& weight, const Tensor& bias,
|
||||
const std::vector<string>& fused_ops, Tensor* output) {
|
||||
DataType dtype = DataTypeToEnum<T>::v();
|
||||
const int num_args = 1;
|
||||
|
||||
if (!NativeFormatEnabled()) {
|
||||
TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(num_args, dtype))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(num_args, DT_UINT8))
|
||||
.Attr("T", dtype)
|
||||
.Attr("transpose_a", false)
|
||||
.Attr("transpose_b", false)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Attr("epsilon", 0.0001)
|
||||
.Attr("_kernel", "MklLayoutDependentOp")
|
||||
.Finalize(node_def()));
|
||||
} else {
|
||||
TF_EXPECT_OK(
|
||||
NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(dtype))
|
||||
.Input(FakeInput(num_args, dtype))
|
||||
.Attr("T", dtype)
|
||||
.Attr("transpose_a", false)
|
||||
.Attr("transpose_b", false)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Attr("epsilon", 0.0001)
|
||||
.Attr("_kernel", "MklNameChangeOp")
|
||||
.Finalize(node_def()));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(InitOp());
|
||||
|
||||
AddInputFromArray<T>(input.shape(), input.flat<T>());
|
||||
AddInputFromArray<T>(weight.shape(), weight.flat<T>());
|
||||
AddInputFromArray<T>(bias.shape(), bias.flat<T>());
|
||||
if (!NativeFormatEnabled()) {
|
||||
// Add MKL meta input for input, filter and bias.
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
const Tensor& output_tensor = *GetOutput(0);
|
||||
if (!NativeFormatEnabled()) {
|
||||
const Tensor& output_meta_tensor = *GetOutput(1);
|
||||
CommonTestUtilities<T> test_util;
|
||||
test_util.PerformConversion(dtype, output_tensor,
|
||||
output_meta_tensor, output);
|
||||
} else {
|
||||
*output = output_tensor;
|
||||
std::vector<Tensor> fused_input = {bias};
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
|
||||
fused_ops.end()) {
|
||||
fused_input.push_back(input);
|
||||
}
|
||||
RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
|
||||
};
|
||||
|
||||
CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
|
||||
@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
|
||||
{"BiasAdd", "Tanh"});
|
||||
}
|
||||
|
||||
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
|
||||
const int batch = 3;
|
||||
const int input_channel = 4;
|
||||
const int output_channel = 4;
|
||||
|
||||
this->VerifyFusedMatMul(batch, input_channel, output_channel,
|
||||
{"BiasAdd", "Add"});
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
|
||||
WithBias, //
|
||||
WithBiasAndRelu, //
|
||||
WithBiasAndRelu6, //
|
||||
WithBiasAndElu, //
|
||||
WithBiasAndTanh);
|
||||
WithBiasAndTanh, //
|
||||
WithBiasAndAdd);
|
||||
|
||||
using MklFusedMatMulDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
|
||||
|
@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
ctx, fused_ops_[0] == "BiasAdd",
|
||||
errors::InvalidArgument(
|
||||
"The 1st post-argument of MklFusedMatMul must be BiasAdd."));
|
||||
if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
|
||||
OP_REQUIRES(
|
||||
ctx, transpose_a_ == false,
|
||||
errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
|
||||
@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
// 2. var, keep the original format to avoid reordering.
|
||||
MklDnnMatMulFwdParams matmul_params(
|
||||
src_dims, weight_dims, bias_dims, dst_dims, src_format,
|
||||
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
|
||||
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
|
||||
MEMORY_FORMAT::nc);
|
||||
|
||||
// Extend the basic parameters for data types and fusions.
|
||||
ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
|
||||
@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
|
||||
matmul_prim->GetPrimitiveDesc();
|
||||
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
|
||||
MKL_TENSOR_FORMAT_NC, &dst_tensor);
|
||||
// The output shape of MatMul is same both for MKL and TF version.
|
||||
// They are all NC format, no matter what's the format of input.
|
||||
// And the shape of AddOp is also the same with output's shape.
|
||||
auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
|
||||
|
||||
MklDnnShape output_mkl_shape;
|
||||
output_mkl_shape.SetMklTensor(false);
|
||||
|
||||
TensorShape output_tf_shape({batch, channel});
|
||||
|
||||
if (fuse_add_) {
|
||||
const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
|
||||
MklDnnShape add_mkl_shape;
|
||||
GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
|
||||
|
||||
// For native format, we need not to set metadata.
|
||||
if (native_format && ctx->forward_input_to_output_with_shape(
|
||||
kInputIndex_Add, kOutputIndex_Dst,
|
||||
output_tf_shape, &dst_tensor)) {
|
||||
; // Need to do nothing for native format
|
||||
} else if (!native_format && ForwardMklTensorInToOutWithMklShape(
|
||||
ctx, kInputIndex_Add, kOutputIndex_Dst,
|
||||
&dst_tensor, output_mkl_shape, false)) {
|
||||
; // If it's not native format, need to forward and set meta first
|
||||
} else {
|
||||
// If forward is not successful, we should use reorder to copy add
|
||||
// tensor to dst tensor
|
||||
AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
|
||||
output_tf_shape, output_mkl_shape,
|
||||
native_format);
|
||||
auto output_format_tag =
|
||||
MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
|
||||
auto add_md =
|
||||
add_mkl_shape.IsMklTensor()
|
||||
? add_mkl_shape.GetMklLayout()
|
||||
: memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
|
||||
auto dst_md =
|
||||
memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
|
||||
|
||||
void* add_buf =
|
||||
static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
|
||||
void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
|
||||
|
||||
if (native_format) {
|
||||
// We are simply deep copying the add_tensor to dst_tensor without
|
||||
// changing memory layout, hence using same memory descriptor.
|
||||
add_md = dst_md =
|
||||
memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
|
||||
mkldnn::memory::format_tag::x);
|
||||
}
|
||||
|
||||
auto fuse_add_src_ =
|
||||
MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
|
||||
auto fuse_add_dst_ =
|
||||
MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
|
||||
auto reorder_desc =
|
||||
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
||||
|
||||
CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
|
||||
this->cpu_engine_, ctx);
|
||||
}
|
||||
} else {
|
||||
TensorShape dst_tensor_shape({batch, channel});
|
||||
MklDnnShape dst_mkl_shape;
|
||||
dst_mkl_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
|
||||
dst_mkl_shape, native_format);
|
||||
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
|
||||
output_mkl_shape, native_format);
|
||||
}
|
||||
|
||||
// if there's nothing to compute, just return.
|
||||
@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
|
||||
} else if (post_op == "Tanh") {
|
||||
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
|
||||
} else if (post_op == "Add") {
|
||||
params.post_op_params.push_back({"sum", {1.0}});
|
||||
} else {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, errors::InvalidArgument(
|
||||
@ -237,10 +296,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
}
|
||||
|
||||
private:
|
||||
bool fuse_add_ = false;
|
||||
bool transpose_a_;
|
||||
bool transpose_b_;
|
||||
std::vector<string> fused_ops_;
|
||||
};
|
||||
const int kInputIndex_Add = 3;
|
||||
const int kOutputIndex_Dst = 0;
|
||||
}; // namespace tensorflow
|
||||
|
||||
// Register mkl kernels for supported operations and types.
|
||||
#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
|
||||
|
@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
|
||||
memory::dims weight_dims;
|
||||
memory::dims bias_dims;
|
||||
memory::dims dst_dims;
|
||||
memory::format_tag src_format;
|
||||
memory::format_tag weight_format;
|
||||
MEMORY_FORMAT src_format;
|
||||
MEMORY_FORMAT weight_format;
|
||||
MEMORY_FORMAT dst_format;
|
||||
string dtypes = string("");
|
||||
struct PostOpParam {
|
||||
string name;
|
||||
@ -57,17 +58,18 @@ struct MklDnnMatMulFwdParams {
|
||||
};
|
||||
std::vector<PostOpParam> post_op_params;
|
||||
|
||||
MklDnnMatMulFwdParams(
|
||||
memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
|
||||
memory::dims dst_dims,
|
||||
memory::format_tag src_format = memory::format_tag::any,
|
||||
memory::format_tag weight_format = memory::format_tag::any)
|
||||
MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
|
||||
memory::dims bias_dims, memory::dims dst_dims,
|
||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
|
||||
MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
|
||||
MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
|
||||
: src_dims(src_dims),
|
||||
weight_dims(weight_dims),
|
||||
bias_dims(bias_dims),
|
||||
dst_dims(dst_dims),
|
||||
src_format(src_format),
|
||||
weight_format(weight_format) {}
|
||||
weight_format(weight_format),
|
||||
dst_format(dst_format) {}
|
||||
};
|
||||
|
||||
// With quantization, input, weight, bias, and output can have different types.
|
||||
@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
|
||||
context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
|
||||
MklDnnType<Toutput>(),
|
||||
memory::format_tag::any));
|
||||
matmul_fwd_params.dst_format));
|
||||
|
||||
context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
|
||||
MklDnnType<Tbias>(),
|
||||
@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
std::vector<float> scales;
|
||||
scales.push_back(post_op_param.param[0]);
|
||||
post_ops_attr.set_output_scales(0, scales);
|
||||
} else if (post_op_param.name == "sum") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
float op_scale = post_op_param.param[0];
|
||||
post_ops.append_sum(op_scale);
|
||||
|
||||
} else {
|
||||
DCHECK((post_op_param.name == "relu") ||
|
||||
(post_op_param.name == "relu6") ||
|
||||
(post_op_param.name == "elu") ||
|
||||
(post_op_param.name == "tanh") ||
|
||||
(post_op_param.name == "sum") ||
|
||||
(post_op_param.name == "output_scale"));
|
||||
}
|
||||
}
|
||||
@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
key_creator.AddAsKey(post_op_param.param[0]);
|
||||
key_creator.AddAsKey(post_op_param.param[1]);
|
||||
key_creator.AddAsKey(post_op_param.param[2]);
|
||||
} else if (post_op_param.name == "sum") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
key_creator.AddAsKey(post_op_param.name);
|
||||
key_creator.AddAsKey(post_op_param.param[0]);
|
||||
} else if (post_op_param.name == "output_scale") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
key_creator.AddAsKey(post_op_param.name);
|
||||
|
Loading…
x
Reference in New Issue
Block a user