Merge pull request #43573 from Intel-tensorflow:utfixes1

PiperOrigin-RevId: 336370647
Change-Id: I1973e6ce009c667484accdca14d0c69e4263d282
This commit is contained in:
TensorFlower Gardener 2020-10-09 15:07:28 -07:00
commit 603b945f8a
2 changed files with 30 additions and 2 deletions

View File

@ -408,7 +408,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.concatv2,
mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
rinfo_.push_back(
{csinfo_.conjugate_transpose,
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
@ -553,7 +553,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_concatv2,
mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.quantized_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
CopyAttrsQuantizedConv2D, AlwaysRewrite,
@ -1501,6 +1501,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
}
return false;
}
// For oneDNN, only int32 is supported for axis data type
static bool ConcatV2Rewrite(const Node* n) {
DataType T;
GetNodeAttr(n->def(), "Tidx", &T);
return (T == DT_INT32);
}
static bool DequantizeRewrite(const Node* n) {
DCHECK(n);

View File

@ -4258,6 +4258,28 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_IndexTest) {
InitGraph(
"node { name: 'A' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT64 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'B' op: 'InputList'"
" attr { key: 'N' value { i: 2 } }}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: 'ConcatV2'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'Tidx' value { type: DT_INT64 } }"
" attr { key: 'N' value { i: 2 } }"
" input: ['B:0', 'B:1', 'A']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
}
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
InitGraph(
"node { name: 'A' op: 'Input'}"