Enable pass to drop shape_invariant attribute from while ops in bridge and perform shape inference and canonicalization after it.

XLA only supports While/WhileRegion ops where the operand and result shapes match. Since we are compiling the clusters to XLA in bridge pipeline, we can drop the attribute and have better shape inference and canonicalization of While/WhileRegion ops.

PiperOrigin-RevId: 360280863
Change-Id: I88bf71af751d6561f264df77386c6fe493ffae5f
This commit is contained in:
Prakalp Srivastava 2021-03-01 14:40:53 -08:00 committed by TensorFlower Gardener
parent 6aca90258a
commit d10071f208

View File

@ -97,13 +97,19 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
// Note that the region-based control-flow produced here still contains
// function call ops which get inlined by the subsequent inliner pass.
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addNestedPass<FuncOp>(
TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types. Also, after dropping `shape_invariant` attribute
// from While/WhileRegion ops within cluster would lead to more precise
// shapes.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addPass(CreateTPUClusterCleanupAttributesPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addNestedPass<FuncOp>(createCSEPass());