Introduce functional<->region conversion passes around extract outside compilation

- Follow functional->region transformation with a inlining pass to make sure calls
  generated by the transform get inlined.

PiperOrigin-RevId: 324282111
Change-Id: Ifaacec3d8919f390fdeda8ca9af129d8e7dce086
This commit is contained in:
Rahul Joshi 2020-07-31 14:12:22 -07:00 committed by TensorFlower Gardener
parent 1edbacaacc
commit 4e11812725

View File

@ -82,15 +82,23 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// Run shape inference so that tf_executor/tf_device ops created later will // Run shape inference so that tf_executor/tf_device ops created later will
// likely to inherit more concrete types. // likely to inherit more concrete types.
pm.addPass(TF::CreateTFShapeInferencePass()); pm.addPass(TF::CreateTFShapeInferencePass());
OpPassManager &func_pm = pm.nest<FuncOp>(); // Encode this in its own scope so that func_pm is not mistakenly used
func_pm.addPass(CreateTPUClusterFormationPass()); // later on.
// Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass {
// because DecomposeResourceOpsPass uses pattern rewriter which hoists OpPassManager &func_pm = pm.nest<FuncOp>();
// changed constants out of tf_device.Launch. func_pm.addPass(CreateTPUClusterFormationPass());
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
func_pm.addPass(CreateTPUHostComputationExpansionPass()); // because DecomposeResourceOpsPass uses pattern rewriter which hoists
pm.addNestedPass<FuncOp>(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); // changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
// Run another shape inference pass because resource decomposition might have // Run another shape inference pass because resource decomposition might have
// created new partial types. // created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass()); pm.addPass(TF::CreateTFShapeInferencePass());