diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 379c9c4dd09..26ecd5307f1 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -921,6 +921,7 @@ cc_library(
         "transforms/optimize.cc",
         "transforms/outside_compiled_to_host_launch.cc",
         "transforms/parallel_execute_to_islands.cc",
+        "transforms/prepare_tpu_computation_for_tf_export.cc",
         "transforms/promote_resources_to_args.cc",
         "transforms/readonly_references_to_resources.cc",
         "transforms/region_control_flow_to_functional.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir
new file mode 100644
index 00000000000..bee99f529b7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/prepare_tpu_computation_for_tf_export.mlir
@@ -0,0 +1,22 @@
+// RUN: tf-opt %s -split-input-file -verify-diagnostics -prepare-tpu-computation-for-tf-export | FileCheck %s
+
+// CHECK-LABEL: @main
+func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) -> (tensor<10x1024xf32>, tensor<128x1024xf32>) {
+
+  // CHECK: %[[SHARDED_ARG0:.*]] = "tf.XlaSharding"(%arg0) {sharding = "\08\03\1A\02\01\02\22\02\00\01"}
+  // CHECK: %[[SHARDED_ARG1:.*]] = "tf.XlaSharding"(%arg1) {sharding = "\08\01\1A\01\01\22\01\00"}
+
+  // CHECK: "tf.Identity"(%[[SHARDED_ARG1]])
+  %0 = "tf.Identity"(%arg1) : (tensor<10x1024xf32>) -> tensor<10x1024xf32>
+
+  // CHECK: "tf.Identity"(%arg2)
+  %1 = "tf.Identity"(%arg2) : (tensor<128x1024xf32>) -> tensor<128x1024xf32>
+  return %0, %1 : tensor<10x1024xf32>, tensor<128x1024xf32>
+}
+
+// -----
+
+// CHECK-NOT: tf.aliasing_output
+func @main(%arg0: tensor<2xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<2xf32>) {
+  return %arg0 : tensor<2xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index b5c489bab39..300257afdae 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -217,6 +217,11 @@ void AddGraphExportLoweringPasses(OpPassManager& pm);
 // single op.
 std::unique_ptr<OperationPass<ModuleOp>> CreateVerifySuitableForExportPass();
 
+// Returns pass that prepares TPU computation to be legal for export to
+// TensorFlow.
+std::unique_ptr<OperationPass<FuncOp>>
+CreatePrepareTpuComputationForTfExportPass();
+
 }  // namespace TF
 
 namespace tf_executor {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc
new file mode 100644
index 00000000000..99baf4138c4
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc
@@ -0,0 +1,79 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TF {
+namespace {
+
+// Sharding annotation attached to function arguments by the MLIR bridge.
+constexpr char kShardingAttr[] = "mhlo.sharding";
+// Input/output aliasing annotation; not representable in a TensorFlow graph.
+constexpr char kAliasingAttr[] = "tf.aliasing_output";
+
+// Rewrites a TPU computation function so that it only uses constructs that
+// are legal in a TensorFlow graph: argument sharding attributes are
+// materialized as tf.XlaSharding ops and XLA specific argument attributes
+// are removed.
+class PrepareTpuComputationForTfExportPass
+    : public PassWrapper<PrepareTpuComputationForTfExportPass, FunctionPass> {
+  void runOnFunction() override;
+};
+
+void PrepareTpuComputationForTfExportPass::runOnFunction() {
+  auto func = getFunction();
+  // Insert at the start of the function body so that the created
+  // tf.XlaSharding ops dominate all existing uses of the arguments.
+  OpBuilder builder(func.getBody());
+  for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
+    if (auto sharding =
+            func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
+      // An empty sharding string means no sharding was specified; only
+      // materialize a tf.XlaSharding op for non-empty shardings.
+      if (!sharding.getValue().empty()) {
+        BlockArgument arg = func.getArgument(i);
+        auto updated_arg = builder.create<TF::XlaShardingOp>(
+            func.getLoc(), arg.getType(), arg, sharding, StringAttr());
+        // Redirect all uses of the argument to the sharded value, except
+        // the tf.XlaSharding op itself which must keep reading the arg.
+        func.getArgument(i).replaceAllUsesExcept(
+            updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
+      }
+
+      func.removeArgAttr(i, builder.getIdentifier(kShardingAttr));
+    }
+
+    // TODO(prakalps, hinsu): Utilize aliasing output attribute instead of
+    // dropping it. This only affects performance and is not required for
+    // correctness.
+    func.removeArgAttr(i, builder.getIdentifier(kAliasingAttr));
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+CreatePrepareTpuComputationForTfExportPass() {
+  return std::make_unique<PrepareTpuComputationForTfExportPass>();
+}
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
index e1b2092a831..6f3006c1da4 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td
@@ -811,3 +811,15 @@ def VerifySuitableForExportPass : Pass<"tf-verify-for-export", "ModuleOp"> {
 
   let constructor = "TF::CreateVerifySuitableForExportPass()";
 }
+
+def PrepareTpuComputationForTfExportPass : Pass<"prepare-tpu-computation-for-tf-export", "FuncOp"> {
+  let summary = "Prepare TPU computation to be legal for export to TensorFlow";
+  let description = [{
+    Prepares TPU computation module attached to _TPUCompileMlir op for
+    TensorFlow graph export by making transformation such as replacing or
+    removing MLIR or XLA specific attributes that are not legal in TensorFlow
+    graph.
+  }];
+
+  let constructor = "TF::CreatePrepareTpuComputationForTfExportPass()";
+}