Add support for Addv2
This commit is contained in:
parent
f67991359e
commit
199d79d32e
@ -206,6 +206,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
|
||||
return false;
|
||||
}
|
||||
bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
|
||||
0 == op_name.compare(GetMklOpName("AddV2")) ||
|
||||
0 == op_name.compare(GetMklOpName("Sub")) ||
|
||||
0 == op_name.compare(GetMklOpName("Mul")) ||
|
||||
0 == op_name.compare(GetMklOpName("Maximum")) ||
|
||||
|
@ -350,6 +350,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
|
||||
// MklInputConversion op is added before it.
|
||||
csinfo_.add = "Add";
|
||||
csinfo_.add_v2 = "AddV2";
|
||||
csinfo_.maximum = "Maximum";
|
||||
csinfo_.mul = "Mul";
|
||||
csinfo_.squared_difference = "SquaredDifference";
|
||||
@ -364,6 +365,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
|
||||
CopyAttrsAll, RewriteIfAtleastOneMklInput,
|
||||
kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back({csinfo_.add_v2,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.add_v2),
|
||||
CopyAttrsAll, RewriteIfAtleastOneMklInput,
|
||||
kRewriteForLayoutPropagation});
|
||||
rinfo_.push_back(
|
||||
{csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
|
||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||
@ -867,6 +872,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
|
||||
typedef struct {
|
||||
string addn;
|
||||
string add;
|
||||
string add_v2;
|
||||
string avg_pool;
|
||||
string avg_pool_grad;
|
||||
string avg_pool3d;
|
||||
|
@ -3776,6 +3776,65 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
|
||||
"B->D:1;C->D:2;D->E:1");
|
||||
}
|
||||
|
||||
// The following positive and negative tests test the rewrite of Add and Addv2
|
||||
// to MKL versions. The operators will be rewritten only if one of the inputs
|
||||
// comes from another MKL operator.
|
||||
TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A']}"
|
||||
"node { name: 'N' op: 'Add'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['M', 'B']}");
|
||||
EXPECT_EQ(
|
||||
DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
|
||||
"|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
|
||||
"M:1->N:2;M:control->DMT/_1:control");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'Add'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);N(Add)|A->N;B->N:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A']}"
|
||||
"node { name: 'N' op: 'AddV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['M', 'B']}");
|
||||
EXPECT_EQ(
|
||||
DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
|
||||
"|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
|
||||
"M:1->N:2;M:control->DMT/_1:control");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'AddV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);N(AddV2)|A->N;B->N:1");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Post-rewrite fixup pass test
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
@ -70,6 +70,8 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
|
||||
|
||||
REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double,
|
||||
int32, int64, bfloat16);
|
||||
REGISTER6(MklBinaryOp, CPU, "_MklAddV2", functor::add, float, Eigen::half,
|
||||
double, int32, int64, bfloat16);
|
||||
REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double,
|
||||
int32, int64, complex64, complex128, bfloat16);
|
||||
REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
|
||||
|
@ -397,6 +397,7 @@ REGISTER_OP("AddV2")
|
||||
.SetIsAggregate()
|
||||
.SetIsCommutative();
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_OP("_MklAdd")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
@ -415,6 +416,21 @@ Returns `x` + `y` element-wise.
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklAddV2")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Input("mkl_x: uint8")
|
||||
.Input("mkl_y: uint8")
|
||||
.Output("z: T")
|
||||
.Output("mkl_z: uint8")
|
||||
.Attr(
|
||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
|
||||
"complex64, complex128}")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.SetIsAggregate()
|
||||
.SetIsCommutative();
|
||||
#endif // INTEL_MKL
|
||||
|
||||
REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
|
||||
shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user