Enable Squeeze fusion in grappler

This commit is contained in:
Mahmoud Abuzaina 2021-01-20 11:39:00 -08:00
parent e96e745fe3
commit afd8d97ae1
2 changed files with 7 additions and 7 deletions

View File

@ -1867,13 +1867,10 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue; continue;
} }
// NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do // NOTE: We can only fuse BatchNorm into Conv2D nodes. In theory we can do
// it for MatMul as well, but in practice this pattern does not appear in // it for MatMul as well, but in practice this pattern does not appear in
// real Tensorflow graphs. // real Tensorflow graphs.
// TODO(penporn):
// Remove this once TF-MKL supports _FusedConv2D with these operations.
#ifndef INTEL_MKL
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze. // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias; ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
if (allow_non_differentiable_rewrites && if (allow_non_differentiable_rewrites &&
@ -1884,6 +1881,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue; continue;
} }
// TODO(intel-tf):
// Remove this once TF-MKL supports _FusedConv2D with these operations.
#ifndef INTEL_MKL
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D; // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
ContractionWithBatchNorm contract_with_batch_norm; ContractionWithBatchNorm contract_with_batch_norm;
if (allow_non_differentiable_rewrites && if (allow_non_differentiable_rewrites &&

View File

@ -932,6 +932,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
} }
} }
#endif // !INTEL_MKL
TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) { TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
using ops::Placeholder; using ops::Placeholder;
@ -1003,7 +1004,6 @@ TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
ASSERT_EQ(tensors.size(), 1); ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
} }
#endif // !INTEL_MKL
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow