Add support for AddV2

AG Ramesh 2019-08-30 13:24:19 -07:00 committed by Penporn Koanantakool
parent f67991359e
commit 199d79d32e
5 changed files with 84 additions and 0 deletions

@@ -206,6 +206,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
    return false;
  }
  bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
                 0 == op_name.compare(GetMklOpName("AddV2")) ||
                 0 == op_name.compare(GetMklOpName("Sub")) ||
                 0 == op_name.compare(GetMklOpName("Mul")) ||
                 0 == op_name.compare(GetMklOpName("Maximum")) ||
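For context, GetMklOpName simply prepends the MKL op-name prefix, so each comparison above matches the MKL-rewritten flavor of the listed op. A minimal sketch of that helper, assuming the "_Mkl" prefix constant defined alongside it:

// Sketch of GetMklOpName as used in the comparisons above; the real
// helper lives next to IsMklElementWiseOp and prepends the MKL prefix.
inline string GetMklOpName(const string& name) {
  static const char* const kMklOpPrefix = "_Mkl";  // constant name assumed
  return string(kMklOpPrefix) + name;              // "AddV2" -> "_MklAddV2"
}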

@@ -350,6 +350,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
    // MklInputConversion op is added before it.
    csinfo_.add = "Add";
    csinfo_.add_v2 = "AddV2";
    csinfo_.maximum = "Maximum";
    csinfo_.mul = "Mul";
    csinfo_.squared_difference = "SquaredDifference";
@@ -364,6 +365,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back({csinfo_.add_v2,
                      mkl_op_registry::GetMklOpName(csinfo_.add_v2),
                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
                      kRewriteForLayoutPropagation});
    rinfo_.push_back(
        {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
         CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
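Each rinfo_ entry above ties an original op name to its MKL replacement, an attribute-copying helper, and a rewrite predicate; RewriteIfAtleastOneMklInput makes the Add/AddV2 rewrite fire only when at least one input already comes from an MKL node. A rough sketch of the record being populated, with field names assumed from the surrounding pass:

// Rough sketch (field names assumed) of the rewrite record that the
// rinfo_.push_back calls above populate; the concrete definition lives
// in MklLayoutRewritePass.
typedef struct {
  string name;      // original op, e.g. csinfo_.add_v2 ("AddV2")
  string new_name;  // MKL replacement, e.g. "_MklAddV2"
  // Copies attributes from the original node onto the rewritten node
  // (CopyAttrsAll above).
  std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs_fn;
  // Decides whether a given node gets rewritten; for Add/AddV2 this is
  // RewriteIfAtleastOneMklInput.
  std::function<bool(const Node*)> rewrite_rule;
  // Reason for the rewrite, e.g. kRewriteForLayoutPropagation.
  RewriteCause rewrite_cause;
} RewriteInfo;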
@@ -867,6 +872,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
  typedef struct {
    string addn;
    string add;
    string add_v2;
    string avg_pool;
    string avg_pool_grad;
    string avg_pool3d;

@@ -3776,6 +3776,65 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
            "B->D:1;C->D:2;D->E:1");
}

// The following positive and negative tests check the rewrite of Add and
// AddV2 to their MKL versions. An operator is rewritten only if at least one
// of its inputs comes from another MKL operator; for inputs that carry no
// MKL layout metadata, the pass inserts dummy MKL tensor (DMT) Const nodes,
// which is why DMT/_0 and DMT/_1 appear in the expected graphs below.
TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: 'Relu'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A']}"
      "node { name: 'N' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['M', 'B']}");
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
      "M:1->N:2;M:control->DMT/_1:control");
}

TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'N' op: 'Add'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);N(Add)|A->N;B->N:1");
}

TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'M' op: 'Relu'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A']}"
      "node { name: 'N' op: 'AddV2'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['M', 'B']}");
  EXPECT_EQ(
      DoMklLayoutOptimizationPass(),
      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
      "M:1->N:2;M:control->DMT/_1:control");
}

TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'N' op: 'AddV2'"
      " attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']}");
  EXPECT_EQ(DoMklLayoutOptimizationPass(),
            "A(Input);B(Input);N(AddV2)|A->N;B->N:1");
}

/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
/////////////////////////////////////////////////////////////////////

@@ -70,6 +70,8 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double,
          int32, int64, bfloat16);
REGISTER6(MklBinaryOp, CPU, "_MklAddV2", functor::add, float, Eigen::half,
          double, int32, int64, bfloat16);
REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double,
          int32, int64, complex64, complex128, bfloat16);
REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
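REGISTER6 (like REGISTER8 below it) is a convenience macro from the cwise-ops registration code that stamps out one kernel registration per listed type. Roughly, a single expansion for float amounts to the sketch below; the exact macro body and the kernel label constant are assumptions here:

// Approximate per-type expansion of the _MklAddV2 registration above
// (label constant assumed); the macro repeats this for all six types.
REGISTER_KERNEL_BUILDER(Name("_MklAddV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklBinaryOp<CPUDevice, functor::add<float>>);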

@@ -397,6 +397,7 @@ REGISTER_OP("AddV2")
    .SetIsAggregate()
    .SetIsCommutative();
#ifdef INTEL_MKL
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
@@ -415,6 +416,21 @@ Returns `x` + `y` element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
REGISTER_OP("_MklAddV2")
.Input("x: T")
.Input("y: T")
.Input("mkl_x: uint8")
.Input("mkl_y: uint8")
.Output("z: T")
.Output("mkl_z: uint8")
.Attr(
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
"complex64, complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.SetIsAggregate()
.SetIsCommutative();
#endif // INTEL_MKL
REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
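Relative to plain AddV2, _MklAddV2 carries two extra uint8 inputs (mkl_x, mkl_y) and an extra uint8 output (mkl_z) for MKL layout metadata; the layout pass wires these up when rewriting, falling back to dummy Const tensors (the DMT nodes in the tests above) when an input has none. A hedged sketch of constructing such a node directly, assuming existing nodes x, y, mkl_x, and mkl_y and a Graph* g:

// Hypothetical manual construction of a _MklAddV2 node; in practice the
// MKL layout pass creates these during its graph rewrite.
Node* mkl_add = nullptr;
TF_CHECK_OK(NodeBuilder("add_v2_mkl", "_MklAddV2")
                .Input(x)      // x: T
                .Input(y)      // y: T
                .Input(mkl_x)  // uint8 layout metadata for x
                .Input(mkl_y)  // uint8 layout metadata for y
                .Attr("T", DT_FLOAT)
                .Finalize(g, &mkl_add));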