diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index cd32dd9cd9d..1d21f5c79b8 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -301,7 +301,7 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
     return std::string("tf.If");
   }
   if (builtin_code == tflite::BuiltinOperator_WHILE) {
-    return std::string("tf.While");
+    return std::string("tfl.while");
   }
 
   llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
@@ -590,9 +590,7 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
     auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
 
     return {builder.getNamedAttr("cond", cond_attr),
-            builder.getNamedAttr("body", body_attr),
-            // TODO(b/139667752): Analyze statelessness correctly
-            builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
+            builder.getNamedAttr("body", body_attr)};
   }
   return {};
 }
@@ -714,6 +712,14 @@ StatusOr<Operation*> ConvertOp(
     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                           builder));
   }
+  if (op_name == "tfl.while") {
+    // Adds two empty regions for "tfl.while". We will fill the regions after
+    // creating the callee functions because the "tfl.while" input/output types
+    // may be different with the callee functions, and the call ops need to sync
+    // with callee function types.
+    op_state.addRegion();
+    op_state.addRegion();
+  }
   if (op_name == "tfl.unidirectional_sequence_lstm") {
     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                           builder));
@@ -755,7 +761,8 @@ StatusOr<Operation*> ConvertOp(
   }
   op_state.addAttributes(attrs);
 
-  // Handle the conversion from subgraph index to functions for If and While
+  // Handle the conversion from subgraph index to functions for If and While. We
+  // will add CallOps in the region to call the functions later for While.
   auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
       op.builtin_options, func_names, builder);
   op_state.addAttributes(function_ref_attrs);
@@ -1163,6 +1170,35 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
   }
   return subgraph.name;
 }
+
+// Adds a CallOp in `region` to call the `func` and returns the results of
+// CallOp.
+void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
+  OpBuilder op_builder{region};
+  region.push_back(new mlir::Block());
+  region.addArguments(func.getType().getInputs());
+  op_builder.setInsertionPointToStart(&region.front());
+  auto call_op = op_builder.create<mlir::CallOp>(
+      region.getLoc(), func.getType().getResults(), func.sym_name(),
+      region.getArguments());
+  op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
+}
+
+// TFL::WhileOp has regions, so we add CallOp to call the FuncOp in the regions
+// if we have while ops.
+void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
+  mlir::SymbolTable symbol_table(module);
+  module.walk([&](mlir::TFL::WhileOp while_op) {
+    auto cond = symbol_table.lookup<mlir::FuncOp>(
+        while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
+    AddCallOpInWhileOpRegion(while_op.cond(), cond);
+    while_op.removeAttr("cond");
+    auto body = symbol_table.lookup<mlir::FuncOp>(
+        while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
+    AddCallOpInWhileOpRegion(while_op.body(), body);
+    while_op.removeAttr("body");
+  });
+}
 }  // namespace
 
 OwningModuleRef tflite::FlatBufferToMlir(
@@ -1219,5 +1255,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
     }
     module.push_back(func_or_error.ConsumeValueOrDie());
   }
+  AddRegionsForTflWhileOp(module);
   return OwningModuleRef(module);
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir
index 8a97a83064f..ba9b52c54a2 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir
@@ -2,12 +2,18 @@
 
 // This test is to test for unranked function output from input, the output type should be compatible with input type.
 
-// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
-// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
-// CHECK: return %0 : tensor<*xf32>
-// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
-// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
-
+// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tf.While"}} {
+// CHECK:   %0 = "tfl.while"(%arg0) ( {
+// CHECK:   ^bb0(%arg1: tensor<*xf32>):
+// CHECK:     %[[RES0:.*]] = call @cond(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK:     "tfl.yield"(%[[RES0]]) : (tensor<*xf32>) -> ()
+// CHECK:   },  {
+// CHECK:   ^bb0(%arg1: tensor<*xf32>):
+// CHECK:     %[[RES1:.*]] = call @body(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK:     "tfl.yield"(%[[RES1]]) : (tensor<*xf32>) -> ()
+// CHECK:   }) : (tensor<1xf32>) -> tensor<*xf32>
+// CHECK:   return %0 : tensor<*xf32>
+// CHECK: }
 func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
   %0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir
index d29c9e661ad..4728802b33f 100644
--- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir
@@ -1,8 +1,14 @@
 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 
-// Check to see if function references in while loops are preserved
-// TODO(b/138222071) Expect first output to be a scalar
-// CHECK:   %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
+// Check to see if nested regions in while loops are preserved
+// CHECK:     %{{.*}}:2 = "tfl.while"(%{{.*}}, %{{.*}}) ( {
+// CHECK:     ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
+// CHECK:       "tfl.yield"(%{{.*}}) : (tensor<*xi1>) -> ()
+// CHECK:     },  {
+// CHECK:     ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
+// CHECK:       "tfl.yield"(%{{.*}}, %{{.*}}) : (tensor<*xi32>, tensor<*xf32>) -> ()
+// CHECK:     }) : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
+
 func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
   // While %arg0 is greater than zero, element wise add %arg1 with itself.
   %0:2 = "tfl.while"(%arg0, %arg1) ( {