Introduce functional<->region conversion passes around extract outside compilation

- Follow functional->region transformation with a inlining pass to make sure calls
  generated by the transform get inlined.

PiperOrigin-RevId: 324282111
Change-Id: Ifaacec3d8919f390fdeda8ca9af129d8e7dce086
This commit is contained in:
Rahul Joshi 2020-07-31 14:12:22 -07:00 committed by TensorFlower Gardener
parent 1edbacaacc
commit 4e11812725

View File

@ -82,15 +82,23 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// Run shape inference so that tf_executor/tf_device ops created later will
// likely to inherit more concrete types.
pm.addPass(TF::CreateTFShapeInferencePass());
OpPassManager &func_pm = pm.nest<FuncOp>();
func_pm.addPass(CreateTPUClusterFormationPass());
// Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
// because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
func_pm.addPass(CreateTPUHostComputationExpansionPass());
pm.addNestedPass<FuncOp>(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
// Encode this in its own scope so that func_pm is not mistakenly used
// later on.
{
OpPassManager &func_pm = pm.nest<FuncOp>();
func_pm.addPass(CreateTPUClusterFormationPass());
// Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
// because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());