Migrate the TPU resource-read-for-write pass to use declarative pass registration instead of manually defined pass registration (NFC).

Pass documentation is also improved and migrated to the declarative pass spec.

PiperOrigin-RevId: 348812111
Change-Id: Ic23c33f59289c834f5461ebd28cf30332f7d8bd0
This commit is contained in:
parent
aaf94e8166
commit
df37c3e1c8
@ -279,6 +279,52 @@ func @_func(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
let constructor = "TFDevice::CreateClusterOutliningPass()";
|
||||
}
|
||||
|
||||
def TPUResourceReadForWritePass : Pass<"tf-tpu-resource-read-for-write", "ModuleOp"> {
  let summary = "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads";

  let description = [{
    This pass materializes `tf.ReadVariableOp` inputs to an outlined TPU computation
    for resource variables where only writes are present so later in the pipeline
    such resource variables can be fused with generated `tf.TPUExecute` ops, which
    only support resource variable read or read + write. For all TPU computations,
    resource variables are required to be initialized prior to execution. Write only
    resource variable uses can be generated currently via packed tensor uses.

    For example, the following:

    ```mlir
    func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
      %0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
      "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
      return
    }

    func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
      return %identity : tensor<i32>
    }
    ```

    will be transformed into:

    ```mlir
    func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
      %resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
      %0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
      "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
      return
    }

    func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
      return %identity : tensor<i32>
    }
    ```
  }];

  let constructor = "TFTPU::CreateTPUResourceReadForWritePass()";
}
|
||||
|
||||
def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
|
||||
let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
@ -32,8 +33,8 @@ namespace TFTPU {
|
||||
// A pass that finds TPU clusters with write only resource access and adds an
|
||||
// associated resource read, so the resource can later be fused into TPUExecute.
|
||||
namespace {
|
||||
struct TPUResourceReadForWrite
|
||||
: public PassWrapper<TPUResourceReadForWrite, OperationPass<ModuleOp>> {
|
||||
struct TPUResourceReadForWritePass
|
||||
: public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
@ -78,7 +79,7 @@ bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
|
||||
return false;
|
||||
}
|
||||
|
||||
void TPUResourceReadForWrite::runOnOperation() {
|
||||
void TPUResourceReadForWritePass::runOnOperation() {
|
||||
SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
|
||||
getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
|
||||
cluster_funcs.push_back(cluster_func);
|
||||
@ -127,13 +128,8 @@ void TPUResourceReadForWrite::runOnOperation() {
|
||||
} // namespace
|
||||
|
||||
// Factory for the TPU resource-read-for-write pass; referenced by the
// `constructor` field of the declarative (tablegen) pass definition.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
  return std::make_unique<TPUResourceReadForWritePass>();
}
|
||||
|
||||
static PassRegistration<TPUResourceReadForWrite> pass(
|
||||
"tf-tpu-resource-read-for-write",
|
||||
"Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes "
|
||||
"with no reads");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user