Migrate TPUExtractOutsideCompilation conversion pass to use declarative pass registration instead of manually defined pass registration (NFC).

Pass documentation is migrated to the declarative pass spec.

PiperOrigin-RevId: 347670003
Change-Id: Id4e9344efb7a7ef2920989a5b0f400ce160252dd
This commit is contained in:
Ken Franko 2020-12-15 12:32:54 -08:00 committed by TensorFlower Gardener
parent ce477dd2bb
commit 5f08955b75
2 changed files with 59 additions and 36 deletions

View File

@ -154,3 +154,59 @@ func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, t
let constructor = "TFTPU::CreateTPUClusterFormationPass()";
}
def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";
let description = [{
This pass extracts a CPU computation cluster with `_xla_outside_compilation`
annotation, which denotes ops that should be run on CPU/host, from a TPU cluster.
Each outside compilation cluster is moved to
a tf_device.parallel_execute region. The TPU cluster is also moved to a
tf_device.parallel_execute region. Communication ops between device and host are
added to pass inputs/outputs to/from the outside compiled region.
For example, the following tf_device.cluster with an op marked for `_xla_outside_compilation`:
```mlir
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
%2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
%3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
tf_device.return %3 : tensor<f32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
return %0 : tensor<f32>
}
```
will become a tf_device.parallel_execute op with a CPU/host region and
a tf_device.cluster with communication ops to send data to/from device/host:
```mlir
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.parallel_execute"() ( {
"tf_device.launch"() ( {
%1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
%2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
%3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
"tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
tf_device.return
}, {
%1 = "tf_device.cluster"() ( {
%2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
%4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tf_device.return %4 : tensor<f32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
tf_device.return %1 : tensor<f32>
}) : () -> tensor<f32>
return %0 : tensor<f32>
}
```
}];
let constructor = "TFTPU::CreateTPUExtractOutsideCompilationPass()";
}

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
@ -53,39 +54,9 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
using OutsideClusterMap =
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
// annotation from a TPU cluster. Each outside compilation cluster is moved to
// a parallel_execute region. The TPU cluster is also moved to a
// parallel_execute region. Communication ops between device and host are
// added to pass inputs/outputs to/from the outside compiled region.
//
// A simple example:
// "tf_device.cluster"() ( {
// "tf.A"()
// "tf.B"() {_xla_outside_compilation = "cluster1"}
// "tf.C"()
// tf_device.return
// }) {num_cores_per_replica = 1, topology = "", device_assignment = []}
//
// Would become the following ops (unimportant attribute, type are omitted):
// "tf_device.parallel_execute"() ( {
// "tf_device.launch"() ( {
// "tf.B()
// tf_device.return
// })
// tf_device.return
// }, {
// "tf_device.cluster"( {
// "tf.A"()
// "tf.C"()
// tf_device.return
// })
// tf_device.return
// })
// Pass implementation that performs the outside-compilation extraction
// described in the (removed) comment block above; its runOnOperation body is
// defined elsewhere in this file.
// NOTE(review): this span is a diff rendering — the PassWrapper base clause is
// the removed pre-migration line and TF::TPUExtractOutsideCompilationPassBase
// (generated from the declarative tablegen spec via passes_detail.h, included
// at the top of this hunk) is the added replacement; only one base clause
// exists in the applied file — confirm against the post-commit source.
struct TPUExtractOutsideCompilation
: public PassWrapper<TPUExtractOutsideCompilation,
OperationPass<ModuleOp>> {
: public TF::TPUExtractOutsideCompilationPassBase<
TPUExtractOutsideCompilation> {
void runOnOperation() override;
};
@ -893,9 +864,5 @@ CreateTPUExtractOutsideCompilationPass() {
return std::make_unique<TPUExtractOutsideCompilation>();
}
static PassRegistration<TPUExtractOutsideCompilation> pass(
"tf-tpu-extract-outside-compilation",
"Extracts TPU outside compilation to separate parallel_execute.");
} // namespace TFTPU
} // namespace mlir