Migrate TPU resource read for write pass to use declarative pass registration instead of manually defined pass registration (NFC).

Pass documentation is also improved and migrated to the declarative pass spec.

PiperOrigin-RevId: 348812111
Change-Id: Ic23c33f59289c834f5461ebd28cf30332f7d8bd0
This commit is contained in:
Andy Ly 2020-12-23 10:09:00 -08:00 committed by TensorFlower Gardener
parent aaf94e8166
commit df37c3e1c8
2 changed files with 51 additions and 9 deletions

View File

@ -279,6 +279,52 @@ func @_func(%arg0: tensor<i32>) -> tensor<i32> {
let constructor = "TFDevice::CreateClusterOutliningPass()";
}
def TPUResourceReadForWritePass : Pass<"tf-tpu-resource-read-for-write", "ModuleOp"> {
let summary = "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads";
let description = [{
This pass materializes `tf.ReadVariableOp` inputs to an outlined TPU computation
for resource variables where only writes are present so later in the pipeline
such resource variables can be fused with generated `tf.TPUExecute` ops, which
only supports resource variable read or read + write. For all TPU computations,
resource variables are required to be initialized prior to execution. Write only
resource variable uses can be generated currently via packed tensor uses.
For example, the following:
```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
%0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
```
will be transformed into:
```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
%resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
%0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
```
}];
let constructor = "TFTPU::CreateTPUResourceReadForWritePass()";
}
def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
namespace mlir {
namespace TFTPU {
@ -32,8 +33,8 @@ namespace TFTPU {
// A pass that finds TPU clusters with write only resource access and adds an
// associated resource read, so the resource can later be fused into TPUExecute.
namespace {
struct TPUResourceReadForWrite
: public PassWrapper<TPUResourceReadForWrite, OperationPass<ModuleOp>> {
struct TPUResourceReadForWritePass
: public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
void runOnOperation() override;
};
@ -78,7 +79,7 @@ bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
return false;
}
void TPUResourceReadForWrite::runOnOperation() {
void TPUResourceReadForWritePass::runOnOperation() {
SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
cluster_funcs.push_back(cluster_func);
@ -127,13 +128,8 @@ void TPUResourceReadForWrite::runOnOperation() {
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
return std::make_unique<TPUResourceReadForWrite>();
return std::make_unique<TPUResourceReadForWritePass>();
}
static PassRegistration<TPUResourceReadForWrite> pass(
"tf-tpu-resource-read-for-write",
"Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes "
"with no reads");
} // namespace TFTPU
} // namespace mlir