diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
index f8574a4e0d3..e9270ff4e54 100644
--- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
     }
   }
 }
+
+TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
+  using ::tensorflow::ops::Placeholder;
+
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({4, 32});
+  auto input_shape_add = ops::Placeholder::Shape({4, 8});
+  auto filter_shape = ops::Placeholder::Shape({32, 8});
+  auto bias_shape = ops::Placeholder::Shape({8});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_add =
+      Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
+  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
+  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+
+  auto fetch = s.WithOpName("fetch");
+  auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
+
+  ops::Identity(fetch, add);
+
+  auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape.shape_.dim_sizes()));
+  auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape_add.shape_.dim_sizes()));
+  auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(filter_shape.shape_.dim_sizes()));
+  auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(bias_shape.shape_.dim_sizes()));
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_tensor},
+               {"filter", filter_tensor},
+               {"bias", bias_tensor},
+               {"input_add", input_add_tensor}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on CPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:CPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    auto fetch_node_name = "add";
+    if (node.name() == fetch_node_name) {
+      EXPECT_EQ("_FusedMatMul", node.op());
+      EXPECT_EQ("input", node.input(0));
+      EXPECT_EQ("filter", node.input(1));
+
+      EXPECT_EQ(2, node.attr().at("num_args").i());
+      EXPECT_EQ("bias", node.input(2));
+      EXPECT_EQ("input_add", node.input(3));
+
+      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      EXPECT_EQ(2, fused_ops.size());
+      EXPECT_EQ("BiasAdd", fused_ops[0]);
+      EXPECT_EQ("Add", fused_ops[1]);
+      found++;
+    }
+  }
+  EXPECT_EQ(1, found);
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
+}
 #endif  // ENABLE_MKLDNN_V1
 
 }  // namespace grappler
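For reference, the rewrite this test expects can be summarized by a small helper; this is an illustrative sketch (the helper name is hypothetical, and the checks simply mirror the EXPECT_* calls in the loop above). Note that input_add is shaped {4, 8}, matching the MatMul output ({4, 32} x {32, 8} -> {4, 8}), so the residual Add operand and the contraction output line up element-wise.

  // Hypothetical helper summarizing what the remapper is expected to emit.
  void ExpectFusedMatMulWithBiasAddAndAdd(const NodeDef& node) {
    EXPECT_EQ("_FusedMatMul", node.op());
    // Data inputs: 0: a, 1: b, 2: bias, 3: residual Add operand.
    EXPECT_EQ("input", node.input(0));
    EXPECT_EQ("filter", node.input(1));
    EXPECT_EQ("bias", node.input(2));
    EXPECT_EQ("input_add", node.input(3));
    EXPECT_EQ(2, node.attr().at("num_args").i());
    const auto& fused_ops = node.attr().at("fused_ops").list().s();
    EXPECT_EQ(2, fused_ops.size());
    EXPECT_EQ("BiasAdd", fused_ops[0]);
    EXPECT_EQ("Add", fused_ops[1]);
  }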
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index b9bd6430991..d7705e91f52 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);
 
-  // MKL version only support fusion for Conv2D
-  DCHECK(IsConv2D(contraction));
+  // The MKL version only supports fusion for Conv2D and MatMul.
+  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
 
-  NodeDef fused_conv2d;
+  NodeDef contraction_node;
   const NodeDef& add = graph->node(matched.add);
 
-  fused_conv2d.set_name(add.name());
-  fused_conv2d.set_op(kFusedConv2D);
-  fused_conv2d.set_device(contraction.device());
-  fused_conv2d.add_input(contraction.input(0));  // 0: input
-  fused_conv2d.add_input(contraction.input(1));  // 1: filter
-  fused_conv2d.add_input(bias_add.input(1));     // 2: bias
+  contraction_node.set_name(add.name());
+  contraction_node.set_device(contraction.device());
+  contraction_node.add_input(
+      contraction.input(0));  // 0: input (conv) / a (matmul)
+  contraction_node.add_input(
+      contraction.input(1));  // 1: filter (conv) / b (matmul)
+  contraction_node.add_input(bias_add.input(1));  // 2: bias
 
-  // Add OP has two inputs, one is conv+bias pattern matched previously,
-  // the other input to add is fused here.
-  fused_conv2d.add_input(add.input(1 - matched.port_id));
+  // The Add op has two inputs: one is the conv+bias / matmul+bias pattern
+  // matched previously; the other Add input is fused here.
+  contraction_node.add_input(add.input(1 - matched.port_id));
 
-  CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
+  if (IsConv2D(contraction)) {
+    contraction_node.set_op(kFusedConv2D);
+    CopyConv2DAttributes(contraction, &contraction_node);
+  } else if (IsMatMul(contraction)) {
+    contraction_node.set_op(kFusedMatMul);
+    CopyMatMulAttributes(contraction, &contraction_node);
+  }
+
+  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
-  mutation->AddNode(std::move(fused_conv2d), &status);
+  mutation->AddNode(std::move(contraction_node), &status);
   TF_RETURN_IF_ERROR(status);
   TF_RETURN_IF_ERROR(mutation->Apply());
@@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 }
 
 #ifdef INTEL_MKL
-bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
+bool IsConv2DOrMatMul(const NodeDef& node) {
+  return IsConv2D(node) || IsMatMul(node);
+}
+
+bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
   const auto* node_view = ctx.graph_view.GetNode(node_index);
   const auto* node_def = node_view->node();
 
   // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
+  // Also candidate for MatMul + Add or MatMul + BiasAdd + Add fusion.
   auto is_supported_add_input = [](const auto* node_view) -> bool {
-    if (IsConv2D(*node_view->node())) return true;
+    // Currently only Conv2D and MatMul are supported.
+    if (IsConv2DOrMatMul(*node_view->node())) return true;
     if (IsBiasAdd(*node_view->node())) {
       if (node_view->NumRegularFanins() < 2) return false;
       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
-      return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
-             IsConv2D(*bias_add_fanin_1.node_view()->node());
+      return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
+             IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
     }
     return false;
  };
@@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
 
 #ifdef INTEL_MKL
   return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
-         IsConv2DWithAdd(ctx, node_index);
+         IsContractionWithAdd(ctx, node_index);
 #else
   return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
          is_batch_norm_fusion_candidate();
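The Conv2D and MatMul cases share the same rewrite above; only the op name and the copied attributes differ. The resulting wiring, sketched as comments (a summary of the add_input calls in AddFusedContractionNode, not additional code):

  // The fused node takes over the Add node's name, so downstream consumers
  // keep their existing input strings.
  //   input(0) = contraction.input(0)    // input (Conv2D) / a (MatMul)
  //   input(1) = contraction.input(1)    // filter (Conv2D) / b (MatMul)
  //   input(2) = bias_add.input(1)       // bias
  //   input(3) = add.input(1 - port_id)  // the other Add operand
  //   op        : _FusedConv2D or _FusedMatMul
  //   fused_ops : {"BiasAdd", "Add"}, num_args = 2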
.Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Attr("T", dtype) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } + + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(weight.shape(), weight.flat()); + for (const Tensor& arg : args) + AddInputFromArray(arg.shape(), arg.flat()); + if (!NativeFormatEnabled()) { + // Add MKL meta input for input, filter and bias. + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + for (const Tensor& arg : args) + AddInputFromArray(dummy_shape, dummy_tensor); + } + + TF_ASSERT_OK(RunOpKernel()); + + const Tensor& output_tensor = *GetOutput(0); + if (!NativeFormatEnabled()) { + const Tensor& output_meta_tensor = *GetOutput(1); + CommonTestUtilities test_util; + test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, + output); + } else { + *output = output_tensor; + } + } + protected: void VerifyFusedMatMul(const int kBatch, const int kInputChannel, const int kOutputChannel, @@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase { next_op = ops::Tanh(root.WithOpName(last_op), next_op); } + if (std::find(fused_ops.begin(), fused_ops.end(), "Add") != + fused_ops.end()) { + last_op = "with_add"; + next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op); + } + CommonTestUtilities::RunAndFetch(root, last_op, output); }; const FusedGraphRunner run_fused = [this](const Tensor& input, const Tensor& weight, const Tensor& bias, const std::vector& fused_ops, Tensor* output) { - DataType dtype = DataTypeToEnum::v(); - const int num_args = 1; - - if (!NativeFormatEnabled()) { - TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Attr("T", dtype) - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("num_args", num_args) - .Attr("fused_ops", fused_ops) - .Attr("epsilon", 0.0001) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); - } else { - TF_EXPECT_OK( - NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Attr("T", dtype) - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("num_args", num_args) - .Attr("fused_ops", fused_ops) - .Attr("epsilon", 0.0001) - .Attr("_kernel", "MklNameChangeOp") - .Finalize(node_def())); - } - - TF_EXPECT_OK(InitOp()); - - AddInputFromArray(input.shape(), input.flat()); - AddInputFromArray(weight.shape(), weight.flat()); - AddInputFromArray(bias.shape(), bias.flat()); - if (!NativeFormatEnabled()) { - // Add MKL meta input for input, filter and bias. 
- AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - } - - TF_ASSERT_OK(RunOpKernel()); - - const Tensor& output_tensor = *GetOutput(0); - if (!NativeFormatEnabled()) { - const Tensor& output_meta_tensor = *GetOutput(1); - CommonTestUtilities test_util; - test_util.PerformConversion(dtype, output_tensor, - output_meta_tensor, output); - } else { - *output = output_tensor; + std::vector fused_input = {bias}; + if (std::find(fused_ops.begin(), fused_ops.end(), "Add") != + fused_ops.end()) { + fused_input.push_back(input); } + RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output); }; CommonTestUtilities::VerifyFusedMatrixClose(kInputChannel, kBatch, @@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) { {"BiasAdd", "Tanh"}); } +TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) { + const int batch = 3; + const int input_channel = 4; + const int output_channel = 4; + + this->VerifyFusedMatMul(batch, input_channel, output_channel, + {"BiasAdd", "Add"}); +} + REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, // WithBias, // WithBiasAndRelu, // WithBiasAndRelu6, // WithBiasAndElu, // - WithBiasAndTanh); + WithBiasAndTanh, // + WithBiasAndAdd); using MklFusedMatMulDataTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest, diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 905abbfeef2..246efacb615 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { ctx, fused_ops_[0] == "BiasAdd", errors::InvalidArgument( "The 1st post-argument of MklFusedMatMul must be BiasAdd.")); + if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true; OP_REQUIRES( ctx, transpose_a_ == false, errors::InvalidArgument("In[0] of MklMatMul can't be transposed.")); @@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // 2. var, keep the original format to avoid reordering. MklDnnMatMulFwdParams matmul_params( src_dims, weight_dims, bias_dims, dst_dims, src_format, - (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format); + (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format, + MEMORY_FORMAT::nc); // Extend the basic parameters for data types and fusions. ExtendMklDnnMatMulFwdParams(ctx, matmul_params); @@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { std::shared_ptr matmul_pd = matmul_prim->GetPrimitiveDesc(); - if (src_mkl_shape.IsMklTensor()) { - this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, - MKL_TENSOR_FORMAT_NC, &dst_tensor); + // The output shape of MatMul is same both for MKL and TF version. + // They are all NC format, no matter what's the format of input. + // And the shape of AddOp is also the same with output's shape. + auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST; + + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + + TensorShape output_tf_shape({batch, channel}); + + if (fuse_add_) { + const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add); + MklDnnShape add_mkl_shape; + GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format); + + // For native format, we need not to set metadata. 
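With the shared RunMklFusedMatMulOp helper, the fused kernel and the reference graph are fed the same operands. For the new WithBiasAndAdd case the reference computation reduces to the following (a sketch; shapes follow the test: batch = 3, input_channel = output_channel = 4):

  // Reference result the fused kernel is compared against for {"BiasAdd", "Add"}:
  //   out = MatMul(input, weight) + bias + input
  // The test reuses `input` as the Add operand, so the element-wise Add is only
  // shape-compatible when kInputChannel == kOutputChannel (both are 4 here).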
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
index 905abbfeef2..246efacb615 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
         ctx, fused_ops_[0] == "BiasAdd",
         errors::InvalidArgument(
             "The 1st post-argument of MklFusedMatMul must be BiasAdd."));
+    if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
@@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     //    2. var, keep the original format to avoid reordering.
     MklDnnMatMulFwdParams matmul_params(
         src_dims, weight_dims, bias_dims, dst_dims, src_format,
-        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
+        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
+        MEMORY_FORMAT::nc);
 
     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     std::shared_ptr matmul_pd = matmul_prim->GetPrimitiveDesc();
 
-    if (src_mkl_shape.IsMklTensor()) {
-      this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
-                                 MKL_TENSOR_FORMAT_NC, &dst_tensor);
+    // The output shape of MatMul is the same for the MKL and TF paths: it is
+    // always in NC format, regardless of the input format, and the Add
+    // operand has the same shape as the output.
+    auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
+
+    MklDnnShape output_mkl_shape;
+    output_mkl_shape.SetMklTensor(false);
+
+    TensorShape output_tf_shape({batch, channel});
+
+    if (fuse_add_) {
+      const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
+      MklDnnShape add_mkl_shape;
+      GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
+
+      // For native format, no metadata needs to be set.
+      if (native_format && ctx->forward_input_to_output_with_shape(
+                               kInputIndex_Add, kOutputIndex_Dst,
+                               output_tf_shape, &dst_tensor)) {
+        ;  // Nothing else to do for native format.
+      } else if (!native_format && ForwardMklTensorInToOutWithMklShape(
+                                       ctx, kInputIndex_Add, kOutputIndex_Dst,
+                                       &dst_tensor, output_mkl_shape, false)) {
+        ;  // For non-native format, forward the tensor and set its metadata.
+      } else {
+        // If forwarding fails, use a reorder to copy the add tensor into the
+        // dst tensor.
+        AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
+                                  output_tf_shape, output_mkl_shape,
+                                  native_format);
+        auto output_format_tag =
+            MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+        auto add_md =
+            add_mkl_shape.IsMklTensor()
+                ? add_mkl_shape.GetMklLayout()
+                : memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+        auto dst_md =
+            memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+
+        void* add_buf =
+            static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
+        void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
+
+        if (native_format) {
+          // We simply deep-copy add_tensor into dst_tensor without changing
+          // the memory layout, hence the same memory descriptor is used.
+          add_md = dst_md =
+              memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
+                           mkldnn::memory::format_tag::x);
+        }
+
+        auto fuse_add_src_ =
+            MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
+        auto fuse_add_dst_ =
+            MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+        auto reorder_desc =
+            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+
+        CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
+                                this->cpu_engine_, ctx);
+      }
     } else {
-      TensorShape dst_tensor_shape({batch, channel});
-      MklDnnShape dst_mkl_shape;
-      dst_mkl_shape.SetMklTensor(false);
-      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
-                                dst_mkl_shape, native_format);
+      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
+                                output_mkl_shape, native_format);
     }
 
     // if there's nothing to compute, just return.
@@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
         params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
       } else if (post_op == "Tanh") {
         params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
+      } else if (post_op == "Add") {
+        params.post_op_params.push_back({"sum", {1.0}});
       } else {
         OP_REQUIRES_OK(
             ctx, errors::InvalidArgument(
@@ -237,10 +296,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
   }
 
  private:
+  bool fuse_add_ = false;
   bool transpose_a_;
   bool transpose_b_;
   std::vector<string> fused_ops_;
-};
+  const int kInputIndex_Add = 3;
+  const int kOutputIndex_Dst = 0;
+};
 
 // Register mkl kernels for supported operations and types.
 #define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \
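The "Add" entry is lowered to oneDNN's sum post-op, which accumulates into the destination buffer. That is why the kernel above first places the Add operand into dst_tensor (by forwarding or by a reorder) before executing the primitive. Roughly, as a sketch of the semantics with the scale of 1.0 pushed in ExtendMklDnnMatMulFwdParams:

  // dst initially holds the Add operand; the primitive then computes
  //   dst = 1.0 * dst + (src * weight + bias)
  // which is exactly BiasAdd followed by Add.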
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index d1e82bf6f02..375047d290f 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
   memory::dims weight_dims;
   memory::dims bias_dims;
   memory::dims dst_dims;
-  memory::format_tag src_format;
-  memory::format_tag weight_format;
+  MEMORY_FORMAT src_format;
+  MEMORY_FORMAT weight_format;
+  MEMORY_FORMAT dst_format;
   string dtypes = string("");
   struct PostOpParam {
     string name;
  };
  std::vector<PostOpParam> post_op_params;
 
-  MklDnnMatMulFwdParams(
-      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
-      memory::dims dst_dims,
-      memory::format_tag src_format = memory::format_tag::any,
-      memory::format_tag weight_format = memory::format_tag::any)
+  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
+                        memory::dims bias_dims, memory::dims dst_dims,
+                        MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
       : src_dims(src_dims),
         weight_dims(weight_dims),
         bias_dims(bias_dims),
         dst_dims(dst_dims),
         src_format(src_format),
-        weight_format(weight_format) {}
+        weight_format(weight_format),
+        dst_format(dst_format) {}
 };
 
 // With quantization, input, weight, bias, and output can have different types.
@@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
     context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                            MklDnnType<Toutput>(),
-                                           memory::format_tag::any));
+                                           matmul_fwd_params.dst_format));
 
     context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                             MklDnnType<Tbias>(),
@@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
           std::vector<float> scales;
           scales.push_back(post_op_param.param[0]);
           post_ops_attr.set_output_scales(0, scales);
+        } else if (post_op_param.name == "sum") {
+          DCHECK_EQ(post_op_param.param.size(), 1);
+          float op_scale = post_op_param.param[0];
+          post_ops.append_sum(op_scale);
+
         } else {
           DCHECK((post_op_param.name == "relu") ||
                  (post_op_param.name == "relu6") ||
                  (post_op_param.name == "elu") ||
                  (post_op_param.name == "tanh") ||
+                 (post_op_param.name == "sum") ||
                  (post_op_param.name == "output_scale"));
         }
       }
@@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory {
       key_creator.AddAsKey(post_op_param.param[0]);
       key_creator.AddAsKey(post_op_param.param[1]);
      key_creator.AddAsKey(post_op_param.param[2]);
+    } else if (post_op_param.name == "sum") {
+      DCHECK_EQ(post_op_param.param.size(), 1);
+      key_creator.AddAsKey(post_op_param.name);
+      key_creator.AddAsKey(post_op_param.param[0]);
     } else if (post_op_param.name == "output_scale") {
       DCHECK_EQ(post_op_param.param.size(), 1);
       key_creator.AddAsKey(post_op_param.name);
       key_creator.AddAsKey(post_op_param.param[0]);
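A minimal sketch of how the fused-Add path in mkl_matmul_op_fused.cc drives this header (illustrative only; the dims/format variables are assumed to be in scope as in the kernel above):

  // Destination layout is pinned to NC and the residual Add becomes a "sum"
  // post-op with scale 1.0.
  MklDnnMatMulFwdParams params(src_dims, weight_dims, bias_dims, dst_dims,
                               src_format, weight_format, MEMORY_FORMAT::nc);
  params.post_op_params.push_back({"sum", {1.0}});
  // The primitive-factory key also hashes the "sum" name and its scale, so
  // primitives with and without the fused Add are cached separately.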