diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 379c9c4dd09..26ecd5307f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -921,6 +921,7 @@ cc_library( "transforms/optimize.cc", "transforms/outside_compiled_to_host_launch.cc", "transforms/parallel_execute_to_islands.cc", + "transforms/prepare_tpu_computation_for_tf_export.cc", "transforms/promote_resources_to_args.cc", "transforms/readonly_references_to_resources.cc", "transforms/region_control_flow_to_functional.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir new file mode 100644 index 00000000000..bee99f529b7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir @@ -0,0 +1,22 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -prepare-tpu-computation-for-tf-export | FileCheck %s + +// CHECK-LABEL: @main +func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) -> (tensor<10x1024xf32>, tensor<128x1024xf32>) { + + // CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) {sharding = "\08\03\1A\02\01\02\22\02\00\01"} + // CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) {sharding = "\08\01\1A\01\01\22\01\00"} + + // CHECK: "tf.Identity"(%[[SHARDED_ARG1]]) + %0 = "tf.Identity"(%arg1) : (tensor<10x1024xf32>) -> tensor<10x1024xf32> + + // CHECK: "tf.Identity"(%arg2) + %1 = "tf.Identity"(%arg2) : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + return %0, %1 : tensor<10x1024xf32>, tensor<128x1024xf32> +} + +// ----- + +// CHECK-NOT: tf.aliasing_output +func @main(%arg0: tensor<2xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<2xf32>) { + return %arg0 : tensor<2xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index b5c489bab39..300257afdae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -217,6 +217,10 @@ void AddGraphExportLoweringPasses(OpPassManager& pm); // single op. std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass(); +// Returns pass that prepares TPU computation to be legal for export to +// TensorFlow. +std::unique_ptr<OperationPass<FuncOp>> +CreatePrepareTpuComputationForTfExportPass(); } // namespace TF namespace tf_executor { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc new file mode 100644 index 00000000000..99baf4138c4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +class PrepareTpuComputationForTfExportPass + : public PassWrapper<PrepareTpuComputationForTfExportPass, FunctionPass> { + void runOnFunction() override; +}; + +void PrepareTpuComputationForTfExportPass::runOnFunction() { + auto func = getFunction(); + OpBuilder builder(func.getBody()); + for (int i = 0; i < func.getNumArguments(); ++i) { + constexpr char kShardingAttr[] = "mhlo.sharding"; + if (auto sharding = + func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) { + if (!sharding.getValue().empty()) { + BlockArgument arg = func.getArgument(i); + auto updated_arg = builder.create<TF::XlaShardingOp>( + func.getLoc(), arg.getType(), arg, sharding, StringAttr()); + func.getArgument(i).replaceAllUsesExcept( + updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg})); + } + + func.removeArgAttr(i, builder.getIdentifier(kShardingAttr)); + } + + // TODO(prakalps, hinsu): Utilize aliasing output attribute instead of + // dropping it. This only affects performance and is not required for + // correctness. + constexpr char kAliasingAttr[] = "tf.aliasing_output"; + func.removeArgAttr(i, builder.getIdentifier(kAliasingAttr)); + } +} + +} // namespace + +std::unique_ptr<OperationPass<FuncOp>> +CreatePrepareTpuComputationForTfExportPass() { + return std::make_unique<PrepareTpuComputationForTfExportPass>(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index e1b2092a831..6f3006c1da4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -811,3 +811,15 @@ def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> { let constructor = "TF::CreateVerifySuitableForExportPass()"; } + +def PrepareTpuComputationForTfExportPass : Pass<"prepare-tpu-computation-for-tf-export", "FuncOp"> { + let summary = "Prepare TPU computation to be legal for export to TensorFlow"; + let description = [{ + Prepares TPU computation module attached to _TPUCompileMlir op for + TensorFlow graph export by making transformation such as replacing or + removing MLIR or XLA specific attributes that are not legal in TensorFlow + graph. + }]; + + let constructor = "TF::CreatePrepareTpuComputationForTfExportPass()"; +}