Migrate TPUExtractOutsideCompilation conversion pass to use declarative pass registration instead of manually defined pass registration (NFC).
Pass documentation is migrated to the declarative pass spec. PiperOrigin-RevId: 347670003 Change-Id: Id4e9344efb7a7ef2920989a5b0f400ce160252dd
This commit is contained in:
parent
ce477dd2bb
commit
5f08955b75
@ -154,3 +154,59 @@ func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, t
|
||||
|
||||
let constructor = "TFTPU::CreateTPUClusterFormationPass()";
|
||||
}
|
||||
|
||||
def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
  let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";

  let description = [{
    This pass extracts a CPU computation cluster with `_xla_outside_compilation`
    annotation, which denotes ops that should be run on CPU/host, from a TPU cluster.
    Each outside compilation cluster is moved to
    a tf_device.parallel_execute region. The TPU cluster is also moved to a
    tf_device.parallel_execute region. Communication ops between device and host are
    added to pass inputs/outputs to/from the outside compiled region.

    For example, the following tf_device.cluster with an op marked for `xla_outside_compilation`:

    ```mlir
    func @outside_compilation() -> tensor<f32> {
      %0 = "tf_device.cluster"() ( {
        %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
        %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
        %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
        tf_device.return %3 : tensor<f32>
      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
      return %0 : tensor<f32>
    }
    ```

    will become a tf_device.parallel_execute op with a CPU/host region and
    a tf_device.cluster with communication ops to send data to/from device/host:

    ```mlir
    func @outside_compilation() -> tensor<f32> {
      %0 = "tf_device.parallel_execute"() ( {
        "tf_device.launch"() ( {
          %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
          %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
          %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
          "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
          tf_device.return
        }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
        tf_device.return
      }, {
        %1 = "tf_device.cluster"() ( {
          %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
          %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
          %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
          tf_device.return %4 : tensor<f32>
        }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
        tf_device.return %1 : tensor<f32>
      }) : () -> tensor<f32>
      return %0 : tensor<f32>
    }
    ```
  }];

  let constructor = "TFTPU::CreateTPUExtractOutsideCompilationPass()";
}
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
|
||||
|
||||
@ -53,39 +54,9 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
|
||||
using OutsideClusterMap =
|
||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
|
||||
|
||||
// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
|
||||
// annotation from a TPU cluster. Each outside compilation cluster is moved to
|
||||
// a parallel_execute region. The TPU cluster is also moved to a
|
||||
// parallel_execute region. Communication ops between device and host are
|
||||
// added to pass inputs/outputs to/from the outside compiled region.
|
||||
//
|
||||
// A simple example:
|
||||
// "tf_device.cluster"() ( {
|
||||
// "tf.A"()
|
||||
// "tf.B"() {_xla_outside_compilation = "cluster1"}
|
||||
// "tf.C"()
|
||||
// tf_device.return
|
||||
// }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
|
||||
//
|
||||
// Would become the following ops (unimportant attribute, type are omitted):
|
||||
// "tf_device.parallel_execute"() ( {
|
||||
// "tf_device.launch"() ( {
|
||||
// "tf.B()
|
||||
// tf_device.return
|
||||
// })
|
||||
// tf_device.return
|
||||
// }, {
|
||||
// "tf_device.cluster"( {
|
||||
// "tf.A"()
|
||||
// "tf.C"()
|
||||
// tf_device.return
|
||||
// })
|
||||
// tf_device.return
|
||||
// })
|
||||
|
||||
struct TPUExtractOutsideCompilation
|
||||
: public PassWrapper<TPUExtractOutsideCompilation,
|
||||
OperationPass<ModuleOp>> {
|
||||
: public TF::TPUExtractOutsideCompilationPassBase<
|
||||
TPUExtractOutsideCompilation> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
@ -893,9 +864,5 @@ CreateTPUExtractOutsideCompilationPass() {
|
||||
return std::make_unique<TPUExtractOutsideCompilation>();
|
||||
}
|
||||
|
||||
static PassRegistration<TPUExtractOutsideCompilation> pass(
|
||||
"tf-tpu-extract-outside-compilation",
|
||||
"Extracts TPU outside compilation to separate parallel_execute.");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
||||
|
Loading…
Reference in New Issue
Block a user