diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 5cb15027fc5..a34be28c809 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -281,7 +281,8 @@ std::unique_ptr> CreateTPUHostComputationExpansionPass(); // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // ops to a separate parallel_execute region to run on CPU. -std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); +std::unique_ptr> +CreateTPUExtractOutsideCompilationPass(); // Populates the supplied passmanager with the passes required to run the void CreateTPUBridgePipeline(OpPassManager& pm); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index a2a19108326..503c9869557 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -49,8 +49,9 @@ using OutsideClusterMap = // TODO(b/154363171): Add example tranformations. struct TPUExtractOutsideCompilation - : public PassWrapper { - void runOnFunction() override; + : public PassWrapper> { + void runOnOperation() override; }; // Collects and clusters ops in `block` with the same `_xla_outside_compilation` @@ -305,9 +306,9 @@ void CreateParallelExecuteFromOutsideClusters( } } -void TPUExtractOutsideCompilation::runOnFunction() { +void TPUExtractOutsideCompilation::runOnOperation() { auto extract_result = - getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { + getOperation().walk([&](tf_device::ClusterOp tpu_cluster) { OutsideClusterMap clusters; if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), &clusters))) @@ -325,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() { } // namespace -std::unique_ptr> +std::unique_ptr> CreateTPUExtractOutsideCompilationPass() { return std::make_unique(); }