Use declarative registration of TPURewritePass (NFC)

Update documentation for the TPURewritePass.

PiperOrigin-RevId: 349467996
Change-Id: Ide122623f51799a1109aa1354a441f864f00250b
Roman Dzhabarov 2020-12-29 14:06:56 -08:00 committed by TensorFlower Gardener
parent 0366e4d699
commit 1f13bffc6c
2 changed files with 121 additions and 74 deletions
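For context, "declarative registration" means the pass is now defined once in the `.td` file shown below and `mlir-tblgen -gen-pass-decls` generates the registration boilerplate (flag name, summary, base class), replacing the hand-written `PassRegistration` removed at the end of this diff. A self-contained sketch of the pattern follows; the base-class shape is indicative only and deliberately avoids real MLIR dependencies, so the mechanics are simplified:

```cpp
// Self-contained sketch of the declarative-registration pattern (no MLIR
// dependency; names mirror this diff but the mechanics are simplified).
#include <iostream>

// Stands in for the base class that `mlir-tblgen -gen-pass-decls` would
// generate from the .td entry below: it owns the flag and summary text.
template <typename DerivedT>
struct TPURewritePassBase {
  static constexpr const char* kArgument = "tf-tpu-rewrite";
  static constexpr const char* kSummary =
      "Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime "
      "operations.";
  virtual void runOnOperation() = 0;
  virtual ~TPURewritePassBase() = default;
};

// The concrete pass only supplies the rewrite body (hence NFC):
struct TPURewritePass : TPURewritePassBase<TPURewritePass> {
  void runOnOperation() override {
    std::cout << "running " << kArgument << "\n";  // placeholder body
  }
};

int main() {
  TPURewritePass pass;
  pass.runOnOperation();
}
```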

@@ -541,3 +541,122 @@ func @computation(%arg0: tensor<i32>) -> tensor<i32> {
let constructor = "TFTPU::CreateTPUResourceReadsWritesPartitioningPass()";
}
def TPURewritePass : Pass<"tf-tpu-rewrite", "ModuleOp"> {
let summary = "Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations.";
let description = [{
This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir`
and `tf.TPUExecute` operations. `tf._TPUCompileMlir` embeds an MLIR module that is
functionally equivalent to the function referenced by `tf_device.cluster_func`,
allowing the module to be jit-compiled and executed on a TPU.
If the operation cannot be rewritten or device assignment fails, the pass
returns a failure.
Note that many parameters of `tf_device.cluster_func` are omitted in this
and the following examples.
For example, a non-replicated `tf_device.cluster_func`:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
return
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.launch"() ( {
%2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
tf_device.return %2 : tensor<i8>
}) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
return
}
```
A replicated `tf_device.cluster_func`:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
%1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
tf_device.return %1 : tensor<i8>
}
return
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
%1:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
tf_device.return %3 : tensor<i8>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
tf_device.return %2 : tensor<i8>
}
return
}
```
A non-replicated `tf_device.cluster_func` with model parallelism:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0:3 = "tf_device.launch"() ( {
%compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.parallel_execute"() ( {
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
tf_device.return %3 : tensor<8xi32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
tf_device.return %2 : tensor<8xi32>
}, {
"tf_device.launch"() ( {
"tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
tf_device.return
}) : () -> tensor<8xi32>
return %1 : tensor<8xi32>
}
```
}];
let constructor = "TFTPU::CreateTPURewritePass()";
}
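The opaque `input_sharding_configuration` / `output_sharding_configuration` strings in the last example are serialized `xla::OpSharding` protos. A hedged sketch of how the `"\08\01\1A\01\01\22\01\00"` value (a maximal sharding pinned to logical device 0) could be produced, assuming TensorFlow's bundled XLA proto header is available:

```cpp
#include <cstdio>
#include <string>
// Assumed available in a TensorFlow build tree.
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // Maximal sharding on logical core 0: type = MAXIMAL (field 1 = 1),
  // tile_assignment_dimensions = [1] (field 3), tile_assignment_devices =
  // [0] (field 4). Serializing yields "\x08\x01\x1a\x01\x01\x22\x01\x00",
  // i.e. the "\08\01\1A\01\01\22\01\00" attribute in the example above.
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::MAXIMAL);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(0);
  const std::string bytes = sharding.SerializeAsString();
  for (unsigned char c : bytes) std::printf("\\%02X", c);
  std::printf("\n");
  return 0;
}
```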

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
@@ -78,24 +79,8 @@ constexpr char kBadArrayElementMsg[] =
constexpr char kBadArrayAttrLengthMsg[] =
"bad '{0}' attribute, expected array attribute of size {1}, got size {2}";
// Rewrites `tf_device.cluster_func` operations assigned to TPU into actual TPU
// jit-compile runtime ops.
//
// For example:
// %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster", func =
// @tpu_func}
// %2 = "tf.SomeOp"(%1)
//
// Would become the following ops (unimportant attributes and types are omitted):
// %1 = "tf.Shape"(%0)
// %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
// "tf.TPUCompileSucceededAssert"(%2#0)
// %3 = "tf.TPUExecute"(%0, %2#1)
// %4 = "tf.SomeOp"(%3)
namespace {
struct TPURewritePass
: public PassWrapper<TPURewritePass, OperationPass<ModuleOp>> {
struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
void runOnOperation() override;
};
@@ -578,59 +563,6 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}
// Rewrites a `tf_device.cluster_func` operation into a set of TPU runtime
// operations that jit-compile and execute the function in
// `tf_device.cluster_func` on TPU. Device assignment is determined from
// available devices in `devices`. If it is not possible to rewrite the
// operation or device assignment fails, a failure will be returned.
//
// For example, a non replicated `tf_device.cluster_func`:
//
// func @main(%arg0: tensor<i1>) {
// %0 = "tf_device.cluster_func"(%arg0)
// {_tpu_replicate = "cluster0", device = "", func = @_func} :
// (tensor<i1>) -> tensor<i1>
// return
// }
//
// will be rewritten as:
//
// func @main(%arg0: tensor<i1>) {
// %0 = "tf.Shape"(%arg0) : (tensor<i1>) -> tensor<?xi32>
// %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} :
// (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
// %2 = "tf.TPUExecute"(%arg0, %1#0) {device = "/TPU:0"} :
// (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
// return
// }
//
// and a replicated `tf_device.cluster_func`:
//
// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
// {n = 2 : i32} {
// %1 = "tf_device.cluster_func"(%ri)
// {_tpu_replicate = "cluster0", device = "", func = @_func} :
// (tensor<i1>) -> tensor<i1>
// tf_device.return %1 : tensor<i1>
// }
// return
// }
//
// will be rewritten as:
//
// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
// {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} {
// %1 = "tf.Shape"(%ri) : (tensor<i1>) -> tensor<?xi32>
// %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} :
// (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
// %3 = "tf.TPUExecute"(%ri, %2#0) :
// (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
// tf_device.return %3 : tensor<i1>
// }
// return
// }
LogicalResult Rewrite(
tf_device::ClusterFuncOp cluster_func,
llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
@@ -831,9 +763,5 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
return std::make_unique<TPURewritePass>();
}
static PassRegistration<TPURewritePass> pass(
"tf-tpu-rewrite",
"Rewriting `tf_device.cluster_func` on TPUs into TPU runtime ops");
} // namespace TFTPU
} // namespace mlir
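With the declarative definition, the pass remains invocable under its `tf-tpu-rewrite` flag (for example, via `tf-opt -tf-tpu-rewrite module.mlir`, which should exercise the generated registration). A minimal sketch of scheduling it programmatically; `AddTPURewrite` is a hypothetical helper, while `CreateTPURewritePass` is the constructor named in the `.td` entry above:

```cpp
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Hypothetical helper: appends the TPU rewrite to a module-level pass
// pipeline. CreateTPURewritePass() comes from the `constructor` field of
// the TPURewritePass definition in the .td file.
void AddTPURewrite(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TFTPU::CreateTPURewritePass());
}
```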