Add prepare-tpu-computation-for-tf-export pass

This pass transforms the MLIR module attached to the _TpuCompileMlir op to be legal for export to a TensorFlow GraphDef. Currently, the pass handles mhlo.sharding attributes on function arguments by materializing them as TF XlaSharding ops, and drops the tf.aliasing_output attribute.

This will be used to run the old tf2xla bridge for the second phase of the compilation.

PiperOrigin-RevId: 361378341
Change-Id: Id418212e65f5d177bc4bb8c279824ef8bfc2add2
This commit is contained in:
Smit Hinsu 2021-03-06 21:20:12 -08:00 committed by TensorFlower Gardener
parent 99b1fa6a15
commit fb7d34de0a
5 changed files with 103 additions and 0 deletions

View File

@ -921,6 +921,7 @@ cc_library(
"transforms/optimize.cc",
"transforms/outside_compiled_to_host_launch.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/prepare_tpu_computation_for_tf_export.cc",
"transforms/promote_resources_to_args.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/region_control_flow_to_functional.cc",

View File

@ -0,0 +1,22 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -prepare-tpu-computation-for-tf-export | FileCheck %s

// Verifies that a non-empty mhlo.sharding argument attribute is materialized
// as a tf.XlaSharding op that feeds the argument's uses, while an argument
// with an empty sharding string (%arg2) is left untouched.

// CHECK-LABEL: @main
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) -> (tensor<10x1024xf32>, tensor<128x1024xf32>) {

// CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) {sharding = "\08\03\1A\02\01\02\22\02\00\01"}
// CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) {sharding = "\08\01\1A\01\01\22\01\00"}

// Uses of sharded arguments are rewritten to go through the XlaSharding op.
// CHECK: "tf.Identity"(%[[SHARDED_ARG1]])
%0 = "tf.Identity"(%arg1) : (tensor<10x1024xf32>) -> tensor<10x1024xf32>

// Empty sharding: the use still refers to the original argument.
// CHECK: "tf.Identity"(%arg2)
%1 = "tf.Identity"(%arg2) : (tensor<128x1024xf32>) -> tensor<128x1024xf32>

return %0, %1 : tensor<10x1024xf32>, tensor<128x1024xf32>
}

// -----

// Verifies that the tf.aliasing_output argument attribute is dropped.

// CHECK-NOT: tf.aliasing_output
func @main(%arg0: tensor<2xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<2xf32>) {
  return %arg0 : tensor<2xf32>
}

View File

@ -217,6 +217,10 @@ void AddGraphExportLoweringPasses(OpPassManager& pm);
// single op.
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
// Returns pass that prepares TPU computation to be legal for export to
// TensorFlow.
std::unique_ptr<OperationPass<FuncOp>>
CreatePrepareTpuComputationForTfExportPass();
} // namespace TF
namespace tf_executor {

View File

@ -0,0 +1,64 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Function pass that rewrites a TPU computation function to be legal for
// export to a TensorFlow GraphDef: non-empty mhlo.sharding argument
// attributes are materialized as tf.XlaSharding ops and tf.aliasing_output
// argument attributes are dropped.
class PrepareTpuComputationForTfExportPass
    : public PassWrapper<PrepareTpuComputationForTfExportPass, FunctionPass> {
  void runOnFunction() override;
};
void PrepareTpuComputationForTfExportPass::runOnFunction() {
  auto func = getFunction();
  // Builder inserts at the start of the function body, so the created
  // XlaSharding ops dominate all existing uses of the arguments.
  OpBuilder builder(func.getBody());
  for (int i = 0; i < func.getNumArguments(); ++i) {
    constexpr char kShardingAttr[] = "mhlo.sharding";
    if (auto sharding =
            func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
      // Only materialize an op for a non-empty sharding string; an empty
      // attribute carries no sharding and is simply removed below.
      if (!sharding.getValue().empty()) {
        BlockArgument arg = func.getArgument(i);
        // Route the argument through a tf.XlaSharding op carrying the
        // serialized sharding so it survives GraphDef export.
        auto updated_arg = builder.create<TF::XlaShardingOp>(
            func.getLoc(), arg.getType(), arg, sharding, StringAttr());
        // Replace every use of the argument except the XlaSharding op
        // itself, which must keep consuming the original argument.
        func.getArgument(i).replaceAllUsesExcept(
            updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
      }
      // The attribute is not legal in a TensorFlow graph; drop it in all
      // cases (empty or not) now that it has been consumed.
      func.removeArgAttr(i, builder.getIdentifier(kShardingAttr));
    }
    // TODO(prakalps, hinsu): Utilize aliasing output attribute instead of
    // dropping it. This only affects performance and is not required for
    // correctness.
    constexpr char kAliasingAttr[] = "tf.aliasing_output";
    func.removeArgAttr(i, builder.getIdentifier(kAliasingAttr));
  }
}
} // namespace
// Factory for the prepare-tpu-computation-for-tf-export pass; referenced by
// the TableGen-declared constructor.
std::unique_ptr<OperationPass<FuncOp>>
CreatePrepareTpuComputationForTfExportPass() {
  auto pass = std::make_unique<PrepareTpuComputationForTfExportPass>();
  return pass;
}
} // namespace TF
} // namespace mlir

View File

@ -811,3 +811,15 @@ def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> {
let constructor = "TF::CreateVerifySuitableForExportPass()";
}
def PrepareTpuComputationForTfExportPass : Pass<"prepare-tpu-computation-for-tf-export", "FuncOp"> {
  let summary = "Prepare TPU computation to be legal for export to TensorFlow";
  let description = [{
    Prepares the TPU computation module attached to the _TPUCompileMlir op for
    TensorFlow graph export by applying transformations such as replacing or
    removing MLIR or XLA specific attributes that are not legal in a TensorFlow
    graph.
  }];
  let constructor = "TF::CreatePrepareTpuComputationForTfExportPass()";
}