Use declarative registration of TPURewritePass (NFC)
Update documentation for the TPURewritePass.

PiperOrigin-RevId: 349467996
Change-Id: Ide122623f51799a1109aa1354a441f864f00250b
parent 0366e4d699
commit 1f13bffc6c
@@ -541,3 +541,122 @@ func @computation(%arg0: tensor<i32>) -> tensor<i32> {
 
   let constructor = "TFTPU::CreateTPUResourceReadsWritesPartitioningPass()";
 }
+
+def TPURewritePass : Pass<"tf-tpu-rewrite", "ModuleOp"> {
+  let summary = "Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations.";
+
+  let description = [{
+    This pass rewrites a `tf_device.cluster_func` operation into a sequence of
+    `tf._TPUCompileMlir` and `tf.TPUExecute` operations. `tf._TPUCompileMlir`
+    contains an MLIR module that is functionally equivalent to the function
+    referenced by `tf_device.cluster_func`, allowing the module to be
+    jit-compiled and executed on a TPU. If it is not possible to rewrite the
+    operation or if device assignment fails, the pass returns a failure.
+
+    Note that many parameters to the `tf_device.cluster_func` are omitted in
+    this and the following examples.
+    For example, a non-replicated `tf_device.cluster_func`:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+      %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+      return
+    }
+    ```
+
+    will be rewritten as:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+      %0:2 = "tf_device.launch"() ( {
+        %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+        tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+      }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+      "tf_device.launch"() ( {
+        "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+        tf_device.return
+      }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+      %1 = "tf_device.launch"() ( {
+        %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+        tf_device.return %2 : tensor<i8>
+      }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
+      return
+    }
+    ```
+
+    A replicated `tf_device.cluster_func`:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+      %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
+        %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+        tf_device.return %1 : tensor<i8>
+      }
+      return
+    }
+    ```
+
+    will be rewritten as:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+      %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
+        %1:2 = "tf_device.launch"() ( {
+          %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+          tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+        }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+        "tf_device.launch"() ( {
+          "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
+          tf_device.return
+        }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+        %2 = "tf_device.launch"() ( {
+          %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+          tf_device.return %3 : tensor<i8>
+        }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
+        tf_device.return %2 : tensor<i8>
+      }
+      return
+    }
+    ```
+
+    A non-replicated `tf_device.cluster_func` with model parallelism:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+      %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
+      return %0 : tensor<8xi32>
+    }
+    ```
+
+    will be rewritten as:
+
+    ```mlir
+    func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+      %0:3 = "tf_device.launch"() ( {
+        %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+        tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
+      }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+      "tf_device.launch"() ( {
+        "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+        tf_device.return
+      }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
+      %1 = "tf_device.parallel_execute"() ( {
+        %2 = "tf_device.launch"() ( {
+          %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
+          tf_device.return %3 : tensor<8xi32>
+        }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
+        tf_device.return %2 : tensor<8xi32>
+      }, {
+        "tf_device.launch"() ( {
+          "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
+          tf_device.return
+        }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
+        tf_device.return
+      }) : () -> tensor<8xi32>
+      return %1 : tensor<8xi32>
+    }
+    ```
+  }];
+
+  let constructor = "TFTPU::CreateTPURewritePass()";
+}
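The `let constructor` line ties the declarative definition to the existing factory function. As a minimal sketch of how the pass is scheduled programmatically (not from this commit; the pass-manager wiring here is illustrative):

```cpp
// Minimal sketch, not from this commit: the pass-manager wiring is
// illustrative. CreateTPURewritePass is the factory named by the
// `let constructor` line above and returns a ModuleOp pass, matching the
// Pass<"tf-tpu-rewrite", "ModuleOp"> declaration.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddTPURewrite(mlir::OpPassManager &pm) {
  // The pass runs on the whole module, so it is added at the module level.
  pm.addPass(mlir::TFTPU::CreateTPURewritePass());
}
```

The same pass stays reachable from `tf-opt` via the `-tf-tpu-rewrite` flag declared in the `Pass<...>` entry above.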

@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

@@ -78,24 +79,8 @@ constexpr char kBadArrayElementMsg[] =
 constexpr char kBadArrayAttrLengthMsg[] =
     "bad '{0}' attribute, expected array attribute of size {1}, got size {2}";
 
-// Rewrites `tf_device.cluster_func` operations assigned to TPU into actual TPU
-// jit-compile runtime ops.
-//
-// For example:
-//   %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster", func =
-//          @tpu_func}
-//   %2 = "tf.SomeOp"(%1)
-//
-// Would become following ops (unimportant attributes, types are omitted):
-//   %1 = "tf.Shape"(%0)
-//   %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
-//   "tf.TPUCompileSucceededAssert"(%2#0)
-//   %3 = "tf.TPUExecute"(%0, %2#1)
-//   %4 = "tf.SomeOp"(%3)
-
 namespace {
-struct TPURewritePass
-    : public PassWrapper<TPURewritePass, OperationPass<ModuleOp>> {
+struct TPURewritePass : public TF::TPURewritePassBase<TPURewritePass> {
   void runOnOperation() override;
 };
 
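The newly included `passes_detail.h` pulls in the tablegen-generated pass base classes, which is what lets the hand-written `PassWrapper` boilerplate be deleted. A rough sketch of the shape of the generated `TPURewritePassBase` (an assumption about mlir-tblgen's `GEN_PASS_CLASSES` output, not code from this commit):

```cpp
// Rough sketch, assuming the conventional GEN_PASS_CLASSES output of
// mlir-tblgen; not code from this commit. The generated base class anchors
// the pass to ModuleOp and carries the "tf-tpu-rewrite" argument, summary,
// and description from the tablegen entry, so the pass itself only needs
// to implement runOnOperation().
template <typename DerivedT>
class TPURewritePassBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
 public:
  TPURewritePassBase()
      : ::mlir::OperationPass<::mlir::ModuleOp>(
            ::mlir::TypeID::get<DerivedT>()) {}
};
```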
@@ -578,59 +563,6 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
   WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
 }
 
-// Rewrites a `tf_device.cluster_func` operation into a set of TPU Runtime
-// Operations that jit-compiles and executes function in
-// `tf_device.cluster_func` on TPU. Device assignment is determined from
-// available devices in `devices`. If it is not possible to rewrite the
-// operation or device assignment fails, a failure will be returned.
-//
-// For example, a non replicated `tf_device.cluster_func`:
-//
-// func @main(%arg0: tensor<i1>) {
-//   %0 = "tf_device.cluster_func"(%arg0)
-//          {_tpu_replicate = "cluster0", device = "", func = @_func} :
-//          (tensor<i1>) -> tensor<i1>
-//   return
-// }
-//
-// will be rewritten as:
-//
-// func @main(%arg0: tensor<i1>) {
-//   %0 = "tf.Shape"(%arg0) : (tensor<i1>) -> tensor<?xi32>
-//   %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} :
-//            (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
-//   %2 = "tf.TPUExecute"(%arg0, %1#0) {device = "/TPU:0"} :
-//          (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
-//   return
-// }
-//
-// and a replicated `tf_device.cluster_func`:
-//
-// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
-//   %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
-//            {n = 2 : i32} {
-//     %1 = "tf_device.cluster_func"(%ri)
-//            {_tpu_replicate = "cluster0", device = "", func = @_func} :
-//            (tensor<i1>) -> tensor<i1>
-//     tf_device.return %1 : tensor<i1>
-//   }
-//   return
-// }
-//
-// will be rewritten as:
-//
-// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
-//   %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
-//            {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} {
-//     %1 = "tf.Shape"(%ri) : (tensor<i1>) -> tensor<?xi32>
-//     %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} :
-//              (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
-//     %3 = "tf.TPUExecute"(%ri, %2#0) :
-//            (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
-//     tf_device.return %3 : tensor<i1>
-//   }
-//   return
-// }
 LogicalResult Rewrite(
     tf_device::ClusterFuncOp cluster_func,
     llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
@@ -831,9 +763,5 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass() {
   return std::make_unique<TPURewritePass>();
 }
 
-static PassRegistration<TPURewritePass> pass(
-    "tf-tpu-rewrite",
-    "Rewriting `tf_device.cluster_func` on TPUs into TPU runtime ops");
-
 } // namespace TFTPU
 } // namespace mlir
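Because the tablegen entry now declares the pass, the static `PassRegistration` object is deleted rather than moved; registration is expected to come from generated code instead. A rough sketch of the kind of hook `GEN_PASS_REGISTRATION` produces (an assumption about the generator's output, not code from this commit):

```cpp
// Rough sketch, assuming mlir-tblgen's GEN_PASS_REGISTRATION output; not
// code from this commit. The generated inline function plays the role of
// the deleted static PassRegistration object above.
#include <memory>
#include "mlir/Pass/PassRegistry.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

inline void registerTPURewritePass() {
  ::mlir::registerPass(
      "tf-tpu-rewrite",
      "Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations.",
      []() -> std::unique_ptr<::mlir::Pass> {
        return mlir::TFTPU::CreateTPURewritePass();
      });
}
```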