Add prepare-tpu-computation-for-tf-export pass
This pass transforms the MLIR module attached to _TpuCompileMlir op to be legal for export to TensorFlow GraphDef. Currently, this pass handles mhlo.sharding attribute on function attributes by passing them through TF XlaSharding ops and dropping tf.aliasing_output attribute. This will be used for using the old tf2xla bridge for second phase of the compilation. PiperOrigin-RevId: 361378341 Change-Id: Id418212e65f5d177bc4bb8c279824ef8bfc2add2
This commit is contained in:
parent
99b1fa6a15
commit
fb7d34de0a
@ -921,6 +921,7 @@ cc_library(
|
||||
"transforms/optimize.cc",
|
||||
"transforms/outside_compiled_to_host_launch.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/prepare_tpu_computation_for_tf_export.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/region_control_flow_to_functional.cc",
|
||||
|
@ -0,0 +1,22 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -prepare-tpu-computation-for-tf-export | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @main
|
||||
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) -> (tensor<10x1024xf32>, tensor<128x1024xf32>) {
|
||||
|
||||
// CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) {sharding = "\08\03\1A\02\01\02\22\02\00\01"}
|
||||
// CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) {sharding = "\08\01\1A\01\01\22\01\00"}
|
||||
|
||||
// CHECK: "tf.Identity"(%[[SHARDED_ARG1]])
|
||||
%0 = "tf.Identity"(%arg1) : (tensor<10x1024xf32>) -> tensor<10x1024xf32>
|
||||
|
||||
// CHECK: "tf.Identity"(%arg2)
|
||||
%1 = "tf.Identity"(%arg2) : (tensor<128x1024xf32>) -> tensor<128x1024xf32>
|
||||
return %0, %1 : tensor<10x1024xf32>, tensor<128x1024xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-NOT: tf.aliasing_output
|
||||
func @main(%arg0: tensor<2xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<2xf32>) {
|
||||
return %arg0 : tensor<2xf32>
|
||||
}
|
@ -217,6 +217,10 @@ void AddGraphExportLoweringPasses(OpPassManager& pm);
|
||||
// single op.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
|
||||
|
||||
// Returns pass that prepares TPU computation to be legal for export to
|
||||
// TensorFlow.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreatePrepareTpuComputationForTfExportPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
|
@ -0,0 +1,64 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
class PrepareTpuComputationForTfExportPass
|
||||
: public PassWrapper<PrepareTpuComputationForTfExportPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void PrepareTpuComputationForTfExportPass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
OpBuilder builder(func.getBody());
|
||||
for (int i = 0; i < func.getNumArguments(); ++i) {
|
||||
constexpr char kShardingAttr[] = "mhlo.sharding";
|
||||
if (auto sharding =
|
||||
func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
|
||||
if (!sharding.getValue().empty()) {
|
||||
BlockArgument arg = func.getArgument(i);
|
||||
auto updated_arg = builder.create<TF::XlaShardingOp>(
|
||||
func.getLoc(), arg.getType(), arg, sharding, StringAttr());
|
||||
func.getArgument(i).replaceAllUsesExcept(
|
||||
updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
|
||||
}
|
||||
|
||||
func.removeArgAttr(i, builder.getIdentifier(kShardingAttr));
|
||||
}
|
||||
|
||||
// TODO(prakalps, hinsu): Utilize aliasing output attribute instead of
|
||||
// dropping it. This only affects performance and is not required for
|
||||
// correctness.
|
||||
constexpr char kAliasingAttr[] = "tf.aliasing_output";
|
||||
func.removeArgAttr(i, builder.getIdentifier(kAliasingAttr));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreatePrepareTpuComputationForTfExportPass() {
|
||||
return std::make_unique<PrepareTpuComputationForTfExportPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -811,3 +811,15 @@ def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> {
|
||||
|
||||
let constructor = "TF::CreateVerifySuitableForExportPass()";
|
||||
}
|
||||
|
||||
def PrepareTpuComputationForTfExportPass : Pass<"prepare-tpu-computation-for-tf-export", "FuncOp"> {
|
||||
let summary = "Prepare TPU computation to be legal for export to TensorFlow";
|
||||
let description = [{
|
||||
Prepares TPU computation module attached to _TPUCompileMlir op for
|
||||
TensorFlow graph export by making transformation such as replacing or
|
||||
removing MLIR or XLA specific attributes that are not legal in TensorFlow
|
||||
graph.
|
||||
}];
|
||||
|
||||
let constructor = "TF::CreatePrepareTpuComputationForTfExportPass()";
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user