diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 2daa8a86d37..d122790af07 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -354,6 +354,7 @@ cc_library(
         ":mhlo_to_lhlo_with_xla",
         ":mlir_hlo_to_hlo",
         ":translate_cl_options",
+        ":type_to_shape",
         "//tensorflow/compiler/jit:xla_cpu_jit",
         "//tensorflow/compiler/jit:xla_gpu_jit",
         "//tensorflow/compiler/mlir/hlo",
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 5c7a592df27..36aa31b0d34 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -497,11 +497,12 @@ class ConvertToHloModule {
   // Multiple return values are always converted to a tuple and returned as a
   // single value.
   explicit ConvertToHloModule(
-      mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
+      mlir::ModuleOp module, xla::XlaBuilder& module_builder,
+      bool use_tuple_args, bool return_tuple,
       tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
       MlirToHloConversionOptions options)
       : module_(module),
-        module_builder_("main"),
+        module_builder_(module_builder),
         use_tuple_args_(use_tuple_args),
         return_tuple_(return_tuple),
         shape_representation_fn_(shape_representation_fn),
@@ -547,14 +548,14 @@ class ConvertToHloModule {
       mlir::CallOp call_op, xla::XlaBuilder* builder,
       ConvertToHloModule::ValueLoweringMap* value_lowering);
 
- private:
   LogicalResult Lower(
       mlir::Operation* inst, bool is_entry_function,
       llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
       xla::XlaBuilder* builder,
       ConvertToHloModule::ValueLoweringMap* value_lowering,
-      xla::XlaComputation* result);
+      xla::XlaOp* return_value);
 
+ private:
   LogicalResult SetEntryTupleShapesAndLeafReplication(
       Block* block, const std::vector<bool>& entry_args_same_across_replicas,
       llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
@@ -569,7 +570,7 @@ class ConvertToHloModule {
   mlir::ModuleOp module_;
 
   // The top-level XlaBuilder.
-  xla::XlaBuilder module_builder_;
+  xla::XlaBuilder& module_builder_;
 
   // Map between function and lowered computation.
   FunctionLoweringMap lowered_computation_;
@@ -1189,7 +1190,9 @@ LogicalResult ConvertToHloModule::Lower(
     llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
     xla::XlaBuilder* builder,
     ConvertToHloModule::ValueLoweringMap* value_lowering,
-    xla::XlaComputation* result) {
+    xla::XlaOp* return_value) {
+  *return_value = xla::XlaOp();
+
   // See MlirToHloConversionOptions for more about layouts.
   auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) {
     if (options_.propagate_layouts) {
@@ -1255,7 +1258,6 @@ LogicalResult ConvertToHloModule::Lower(
   if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
     // Construct the return value for the function. If there are multiple
     // values returned, then create a tuple, else return value directly.
-    xla::XlaOp return_value;
     unsigned num_return_values = inst->getNumOperands();
     if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
       const bool has_ret_shardings =
@@ -1291,24 +1293,16 @@ LogicalResult ConvertToHloModule::Lower(
         builder->SetSharding(sharding);
       }
 
-      return_value = xla::Tuple(builder, returns);
+      *return_value = xla::Tuple(builder, returns);
       builder->ClearSharding();
     } else if (num_return_values == 1) {
       xla::XlaOp operand;
       if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
         return failure();
 
-      return_value = operand;
+      *return_value = operand;
     }
 
-    // Build the XlaComputation and check for failures.
-    auto computation_or =
-        return_value.valid() ? builder->Build(return_value) : builder->Build();
-    if (!computation_or.ok()) {
-      inst->emitError(llvm::Twine(computation_or.status().error_message()));
-      return failure();
-    }
-    *result = std::move(computation_or.ValueOrDie());
     return success();
   }
 
@@ -1515,11 +1509,21 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     }
   }
 
+  xla::XlaOp return_value;
   for (auto& inst : *block)
     if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
-                     &lowering, result)))
+                     &lowering, &return_value)))
       return failure();
 
+  // Build the XlaComputation and check for failures.
+  auto computation_or =
+      return_value.valid() ? builder->Build(return_value) : builder->Build();
+  if (!computation_or.ok()) {
+    block->back().emitError(
+        llvm::Twine(computation_or.status().error_message()));
+    return failure();
+  }
+  *result = std::move(computation_or.ValueOrDie());
   return success();
 }
 
@@ -1704,7 +1708,8 @@ Status ConvertRegionToComputation(mlir::Region* region,
                                   xla::XlaComputation* func,
                                   MlirToHloConversionOptions options) {
   mlir::ModuleOp module;
-  ConvertToHloModule converter(module, true, true, {}, options);
+  xla::XlaBuilder module_builder("main");
+  ConvertToHloModule converter(module, module_builder, true, true, {}, options);
   if (failed(converter.LowerRegionAsComputation(region, func)))
     return tensorflow::errors::Internal(
         "failed to convert region to computation");
@@ -1717,14 +1722,55 @@ Status ConvertMlirHloToHlo(
     const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     MlirToHloConversionOptions options) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
-  ConvertToHloModule converter(module, use_tuple_args, return_tuple,
-                               shape_representation_fn, options);
+  xla::XlaBuilder module_builder("main");
+  ConvertToHloModule converter(module, module_builder, use_tuple_args,
+                               return_tuple, shape_representation_fn, options);
   if (failed(converter.Run())) return diag_handler.ConsumeStatus();
   auto hlo_module = converter.ConsumeMainProto();
   hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
   if (failed(AddDynamicParameterBindings(
           module, hlo_proto->mutable_hlo_module(), use_tuple_args)))
     return diag_handler.ConsumeStatus();
+  return Status::OK();
+}
+
+Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
+                           llvm::ArrayRef<xla::XlaOp> xla_params,
+                           std::vector<xla::XlaOp>& returns,
+                           MlirToHloConversionOptions options) {
+  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
+  ConvertToHloModule converter(module, builder,
+                               /*use_tuple_args=*/false, /*return_tuple=*/false,
+                               /*shape_representation_fn=*/nullptr, options);
+
+  ConvertToHloModule::ValueLoweringMap lowering;
+  if (xla_params.size() != block.getArguments().size())
+    return tensorflow::errors::Internal(
+        "xla_params size != block arguments size");
+  for (BlockArgument& arg : block.getArguments()) {
+    auto num = arg.getArgNumber();
+    lowering[arg] = xla_params[num];
+  }
+
+  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
+  for (auto& inst : block) {
+    if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
+      returns.resize(inst.getNumOperands());
+      for (OpOperand& ret : inst.getOpOperands()) {
+        unsigned index = ret.getOperandNumber();
+        xla::XlaOp operand;
+        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
+          return diag_handler.ConsumeStatus();
+        returns[index] = operand;
+      }
+    } else {
+      xla::XlaOp return_value;
+      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
+                                 /*ret_shardings=*/{}, &builder, &lowering,
+                                 &return_value)))
+        return diag_handler.ConsumeStatus();
+    }
+  }
 
   return Status::OK();
 }
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index a260a797354..a1c1cb5c7da 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -52,6 +52,14 @@ Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                                shape_representation_fn = nullptr,
                            MlirToHloConversionOptions options = {});
 
+// Transforms a Block into HLO, where the HLO is represented as calls into an
+// XlaBuilder. The Block may call functions defined in its ancestor ModuleOp.
+// xla_params are the block inputs; returns receives the returned XlaOps.
+Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
+                           llvm::ArrayRef<xla::XlaOp> xla_params,
+                           std::vector<xla::XlaOp>& returns,
+                           MlirToHloConversionOptions options = {});
+
 // Converts a region to a computation. It returns a standalone module that
 // contains the converted region as the entry computation.
 Status ConvertRegionToComputation(mlir::Region* region,
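The intended call pattern for BuildHloFromMlirHlo, as a minimal sketch (the `module` variable stands in for an already-parsed ModuleOp; ConvertMlirHloToHloViaBuilder in xla_mlir_translate.cc below follows these same steps):

  mlir::FuncOp main = module.lookupSymbol<mlir::FuncOp>("main");
  mlir::Block& block = main.getRegion().front();
  xla::XlaBuilder builder("main");

  // One xla::Parameter per block argument, in argument order.
  std::vector<xla::XlaOp> xla_params;
  for (mlir::BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    xla::Shape shape = xla::TypeToShape(arg.getType());
    xla_params.push_back(
        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num)));
  }

  // Lower the block's body into the builder.
  std::vector<xla::XlaOp> returns;
  TF_RETURN_IF_ERROR(
      mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns));

  // Pick the root: a single return is used directly, multiple returns
  // become a tuple, and no return leaves the builder's last op as root.
  xla::XlaOp return_value;
  if (returns.size() == 1)
    return_value = returns[0];
  else if (returns.size() > 1)
    return_value = xla::Tuple(&builder, returns);
  TF_ASSIGN_OR_RETURN(
      xla::XlaComputation computation,
      return_value.valid() ? builder.Build(return_value) : builder.Build());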
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 61686e13b26..b3d3603ae41 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -1,4 +1,5 @@
 // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text-via-builder %s | FileCheck %s
 
 // CHECK:  HloModule
 func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
@@ -1004,22 +1005,6 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
 
 // -----
 
-// Tests that the exported HLO module keeps parameter replication annotation.
-
-// CHECK:  HloModule
-func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
-  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
-  return %0 : tensor<16x16xf32>
-}
-
-// CHECK:  ENTRY
-// CHECK:  %[[ARG0:.*]] = f32[16,16] parameter(0)
-// CHECK-NOT: parameter_replication={true}
-// CHECK:  %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
-// CHECK:  ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
-
-// -----
-
 // CHECK:  HloModule
 func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
   %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir
new file mode 100644
index 00000000000..40012f18c71
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export_replicas.mlir
@@ -0,0 +1,15 @@
+// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+// Tests that the exported HLO module keeps parameter replication annotation.
+
+// CHECK:  HloModule
+func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
+  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// CHECK:  ENTRY
+// CHECK:  %[[ARG0:.*]] = f32[16,16] parameter(0)
+// CHECK-NOT: parameter_replication={true}
+// CHECK:  %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true}
+// CHECK:  ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]])
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
index 1be19de10c0..cc8c23ca124 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
+#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -124,16 +125,58 @@ static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
   return HloModule::CreateFromProto(module_proto, module_config);
 }
 
+// Wraps BuildHloFromMlirHlo so the output is an HloProto matching what
+// ConvertMlirHloToHlo produces for the same module.
+Status ConvertMlirHloToHloViaBuilder(mlir::ModuleOp module,
+                                     ::xla::HloProto* hlo_proto,
+                                     mlir::MlirToHloConversionOptions options) {
+  mlir::FuncOp main = module.lookupSymbol<mlir::FuncOp>("main");
+  mlir::Block& block = main.getRegion().front();
+  xla::XlaBuilder builder("main");
+
+  // Create xla_params.
+  std::vector<xla::XlaOp> xla_params;
+  for (mlir::BlockArgument& arg : block.getArguments()) {
+    auto num = arg.getArgNumber();
+    xla::Shape shape = xla::TypeToShape(arg.getType());
+    xla::XlaOp argop =
+        xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
+    xla_params.push_back(argop);
+  }
+
+  std::vector<xla::XlaOp> returns(1);  // resized by BuildHloFromMlirHlo
+  TF_RETURN_IF_ERROR(
+      mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns, options));
+
+  xla::XlaOp return_value;
+  if (returns.size() == 1)
+    return_value = returns[0];
+  else if (returns.size() > 1)
+    return_value = xla::Tuple(&builder, returns);
+
+  TF_ASSIGN_OR_RETURN(
+      xla::XlaComputation computation,
+      return_value.valid() ? builder.Build(return_value) : builder.Build());
+  auto hlo_module = computation.proto();
+  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
+
+  return Status::OK();
+}
+
 static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
-    mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) {
+    mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts,
+    bool via_builder) {
   if (!module) return mlir::failure();
 
   HloProto hloProto;
   mlir::MlirToHloConversionOptions options;
   options.propagate_layouts = with_layouts;
-  Status status = mlir::ConvertMlirHloToHlo(
-      module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
-      /*shape_representation_fn=*/nullptr, options);
+  Status status =
+      via_builder
+          ? ConvertMlirHloToHloViaBuilder(module, &hloProto, options)
+          : mlir::ConvertMlirHloToHlo(
+                module, &hloProto, emit_use_tuple_arg, emit_return_tuple,
+                /*shape_representation_fn=*/nullptr, options);
   if (!status.ok()) {
     LOG(ERROR) << "Module conversion failed: " << status;
     return mlir::failure();
@@ -167,13 +210,24 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
 static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
     mlir::ModuleOp module, llvm::raw_ostream& output) {
   return MlirHloToHloTextTranslateFunctionImpl(module, output,
-                                               /*with_layouts=*/false);
+                                               /*with_layouts=*/false,
+                                               /*via_builder=*/false);
 }
 
 static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction(
     mlir::ModuleOp module, llvm::raw_ostream& output) {
   return MlirHloToHloTextTranslateFunctionImpl(module, output,
-                                               /*with_layouts=*/true);
+                                               /*with_layouts=*/true,
+                                               /*via_builder=*/false);
+}
+
+// Converts MLIR HLO to HLO by first lowering through XlaBuilder. This
+// exercises the BuildHloFromMlirHlo conversion path for testing.
+static mlir::LogicalResult MlirHloToHloTextViaBuilderTranslateFunction(
+    mlir::ModuleOp module, llvm::raw_ostream& output) {
+  return MlirHloToHloTextTranslateFunctionImpl(module, output,
+                                               /*with_layouts=*/false,
+                                               /*via_builder=*/true);
 }
 
 }  // namespace xla
@@ -194,6 +248,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate(
     "mlir-hlo-to-hlo-text-with-layouts",
     xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects);
 
+static mlir::TranslateFromMLIRRegistration MlirHloToHloTextViaBuilderTranslate(
+    "mlir-hlo-to-hlo-text-via-builder",
+    xla::MlirHloToHloTextViaBuilderTranslateFunction, RegisterInputDialects);
+
 static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
     "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);