diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 6eefdf16067..8558dd90fca 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -97,13 +97,19 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { func_pm.addPass(CreateTPUHostComputationExpansionPass()); func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); } - // Run another shape inference pass because resource decomposition might have - // created new partial types. - pm.addPass(TF::CreateTFShapeInferencePass()); + // Note that the region-based control-flow produced here still contains // function call ops which get inlined by the subsequent inliner pass. pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass( + TF::CreateDropWhileShapeInvariantInDeviceClusterPass()); + // Run another shape inference pass because resource decomposition might have + // created new partial types. Also, after dropping `shape_invariant` attribute + // from While/WhileRegion ops within cluster would lead to more precise + // shapes. + pm.addPass(TF::CreateTFShapeInferencePass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addPass(CreateTPUClusterCleanupAttributesPass()); pm.addPass(TFDevice::CreateResourceOpLiftingPass()); pm.addNestedPass(createCSEPass());