Merge pull request #43573 from Intel-tensorflow:utfixes1
PiperOrigin-RevId: 336370647 Change-Id: I1973e6ce009c667484accdca14d0c69e4263d282
This commit is contained in:
commit
603b945f8a
@ -408,7 +408,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
rinfo_.push_back({csinfo_.concatv2,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.concatv2),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
|
||||
rinfo_.push_back(
|
||||
{csinfo_.conjugate_transpose,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
|
||||
@ -553,7 +553,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
rinfo_.push_back({csinfo_.quantized_concatv2,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
|
||||
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
|
||||
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
|
||||
rinfo_.push_back({csinfo_.quantized_conv2d,
|
||||
mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
|
||||
CopyAttrsQuantizedConv2D, AlwaysRewrite,
|
||||
@ -1501,6 +1501,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// For oneDNN, only int32 is supported for axis data type
|
||||
static bool ConcatV2Rewrite(const Node* n) {
|
||||
DataType T;
|
||||
GetNodeAttr(n->def(), "Tidx", &T);
|
||||
return (T == DT_INT32);
|
||||
}
|
||||
|
||||
static bool DequantizeRewrite(const Node* n) {
|
||||
DCHECK(n);
|
||||
|
@ -4258,6 +4258,28 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
|
||||
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_IndexTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT64 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'B' op: 'InputList'"
|
||||
" attr { key: 'N' value { i: 2 } }}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'ConcatV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'Tidx' value { type: DT_INT64 } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['B:0', 'B:1', 'A']}"
|
||||
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
|
||||
"A->D:2;B->D;B:1->D:1;C->E;D->E:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
|
Loading…
Reference in New Issue
Block a user