From 4e118127252147ab755b1f92c8127f93564205f8 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Fri, 31 Jul 2020 14:12:22 -0700 Subject: [PATCH] Introduce functional<->region conversion passes around extract outside compilation - Follow functional->region transformation with a inlining pass to make sure calls generated by the transform get inlined. PiperOrigin-RevId: 324282111 Change-Id: Ifaacec3d8919f390fdeda8ca9af129d8e7dce086 --- .../mlir/tensorflow/transforms/bridge.cc | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index cb1dd2332a8..ed0528ae054 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -82,15 +82,23 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { // Run shape inference so that tf_executor/tf_device ops created later will // likely to inherit more concrete types. pm.addPass(TF::CreateTFShapeInferencePass()); - OpPassManager &func_pm = pm.nest(); - func_pm.addPass(CreateTPUClusterFormationPass()); - // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass - // because DecomposeResourceOpsPass uses pattern rewriter which hoists - // changed constants out of tf_device.Launch. - func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); - func_pm.addPass(CreateTPUHostComputationExpansionPass()); - pm.addNestedPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); + // Encode this in its own scope so that func_pm is not mistakenly used + // later on. + { + OpPassManager &func_pm = pm.nest(); + func_pm.addPass(CreateTPUClusterFormationPass()); + // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass + // because DecomposeResourceOpsPass uses pattern rewriter which hoists + // changed constants out of tf_device.Launch. + func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); + func_pm.addPass(CreateTPUHostComputationExpansionPass()); + func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); + } + pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); + // Run another shape inference pass because resource decomposition might have // created new partial types. pm.addPass(TF::CreateTFShapeInferencePass());