diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index 976ff91f6ce..1ea0e797675 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -45,22 +45,32 @@ namespace tensorflow {
 namespace {
 
 // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
-// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a
-// dimension is static, a constant dimension is returned.
+// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
+// dimension is static, a constant dimension is returned. If a dim is dynamic, a
+// dynamic XlaOp representing the dynamic size is returned.
 xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
     XlaOpKernelContext* ctx, const xla::Shape& element_shape,
     const xla::Shape& list_shape, int64 num_elements) {
   std::vector<int64> dynamic_sizes;
-  ctx->set_dynamic_dimension_is_minus_one(true);
   // The multiplier can be a dynamic value.
   TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
+  std::vector<bool> dims_are_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(
+      ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
   std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
   // Set dynamic dimension size to 0 for initialization value.
   std::vector<xla::XlaOp> dynamic_dims;
-  // Leading dim is a static dimension.
-  dynamic_dims.push_back(xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  if (leading_dim_is_dynamic) {
+    dynamic_dims.push_back(ctx->Input(1));
+  } else {
+    dynamic_dims.push_back(
+        xla::ConstantR0<int32>(ctx->builder(), num_elements));
+  }
   for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
-    if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) {
+    if (dims_are_dynamic[dim]) {
       auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
       dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
       dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
@@ -80,11 +90,12 @@ class TensorListLengthOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* ctx) override {
     int64 leading_dim;
-    OP_REQUIRES_OK(ctx,
-                   GetLeadingDimForTensorList(ctx->Input(0), &leading_dim));
-    Tensor length_tensor(DT_INT32, {});
-    length_tensor.scalar<int32>()() = static_cast<int32>(leading_dim);
-    ctx->SetConstantOutput(0, length_tensor);
+    xla::XlaOp leading_dim_size;
+    bool leading_dim_is_dynamic;
+    OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
+                                                   &leading_dim_is_dynamic,
+                                                   &leading_dim_size));
+    ctx->SetOutput(0, leading_dim_size);
   }
 
  private:
@@ -134,6 +145,9 @@ class TensorListReserveOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(
         ctx, num_elements >= 0,
         errors::InvalidArgument(
@@ -156,7 +170,8 @@ class TensorListReserveOp : public XlaOpKernel {
     if (got_shape) {
       xla::Shape list_shape;
       OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                              element_shape, num_elements, &list_shape));
+                              element_shape, num_elements,
+                              num_element_is_dynamic, &list_shape));
       // Set up dynamic dimension sizes to create the zero tensor.
       auto list_dynamic_dims_or = GetTensorListDynamicDims(
           ctx, element_shape, list_shape, num_elements);
@@ -175,8 +190,8 @@ class TensorListReserveOp : public XlaOpKernel {
       return;
     }
 
-    xla::XlaOp result =
-        BuildUninitializedTensorList(ctx->builder(), num_elements);
+    xla::XlaOp result = BuildUninitializedTensorList(
+        ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
     ctx->SetTensorListOutput(0, result);
   }
 
@@ -200,6 +215,9 @@ class EmptyTensorListOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     int64 max_num_elements;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
+    bool num_element_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
     OP_REQUIRES(ctx, max_num_elements >= 0,
                 errors::InvalidArgument(
                     "XLA compilation requires a fixed tensor list size. Set "
@@ -210,9 +228,9 @@ class EmptyTensorListOp : public XlaOpKernel {
 
     if (dtype_ != DT_VARIANT) {
       // We are creating a non-nested TensorList.
-      // If element shape is compile time constant and it's not "unknown rank"
-      // shape (-1), create an initialized TensorList. Otherwise create an
-      // uninitialized TensorList.
+      // If element shape is compile time constant and it's not "unknown
+      // rank" shape (-1), create an initialized TensorList. Otherwise
+      // create an uninitialized TensorList.
       xla::XlaOp element_shape_handle = ctx->Input(0);
       xla::PrimitiveType type;
       OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
@@ -224,7 +242,8 @@ class EmptyTensorListOp : public XlaOpKernel {
       if (got_shape) {
         xla::Shape list_shape;
         OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
-                                element_shape, max_num_elements, &list_shape));
+                                element_shape, max_num_elements,
+                                num_element_is_dynamic, &list_shape));
         // Set up dynamic dimension sizes to create the zero tensor.
         auto list_dynamic_dims_or = GetTensorListDynamicDims(
             ctx, element_shape, list_shape, max_num_elements);
@@ -243,7 +262,8 @@ class EmptyTensorListOp : public XlaOpKernel {
     // We are creating a nested TensorList or a non-nested TensorList with
     // unknown shape. Just create an uninitialized TensorList.
     xla::XlaOp result =
-        BuildUninitializedTensorList(ctx->builder(), max_num_elements);
+        BuildUninitializedTensorList(ctx->builder(), max_num_elements,
+                                     num_element_is_dynamic, ctx->Input(1));
     ctx->SetTensorListOutput(0, result);
   }
 
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
index 0e367e10ec4..156f9bfea40 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc
@@ -189,28 +189,42 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
 }
 
 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension) {
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size) {
   auto zero =
       xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
-  return xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
+  if (leading_size_is_dynamic) {
+    return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
+  } else {
+    return broadcast;
+  }
 }
 
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) {
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size) {
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
   if (is_initialized) {
     auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
+    *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
+    auto buffer = xla::GetTupleElement(list, 0);
     *leading_dim = buffer_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
   } else {
+    *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
     *leading_dim = list_shape.dimensions(0);
+    *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
   }
   return Status::OK();
 }
 
 Status GetTensorListShapeFromElementTensorListShape(
     const xla::Shape& element_tensor_list_shape, int64 leading_dim,
-    xla::Shape* tensor_list_shape) {
+    bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
   std::vector<xla::Shape> shapes;
   int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
   for (int i = 0; i < tuple_size; i++) {
@@ -220,6 +234,9 @@ Status GetTensorListShapeFromElementTensorListShape(
     dimensions.insert(dimensions.begin(), leading_dim);
     shapes.push_back(
         xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
+    if (leading_dim_is_dynamic) {
+      shapes.back().set_dynamic_dimension(0, true);
+    }
   }
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
@@ -229,6 +246,7 @@ Status GetTensorListShapeFromElementTensorListShape(
 
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape) {
   if (!element_shape.IsArray()) {
     return errors::InvalidArgument(
@@ -236,12 +254,12 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
         "shape. But element shape is ",
         element_shape.DebugString());
   }
-
   std::vector<xla::Shape> shapes;
   std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
   dimensions.insert(dimensions.begin(), leading_dim);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
+  shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
   shapes.push_back(
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
   *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
@@ -279,7 +297,10 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
                                           bool element_is_tensor_list,
                                           xla::XlaOp* initialized_list) {
   int64 leading_dim;
-  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim));
+  xla::XlaOp leading_dim_dynamic_size;
+  bool leading_dim_is_dynamic;
+  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
+      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));
 
   xla::XlaBuilder* b = list.builder();
   xla::Shape list_shape;
@@ -287,12 +308,11 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
 
   if (element_is_tensor_list) {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   } else {
     TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
-        element_shape, leading_dim, &list_shape));
+        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
   }
-
   bool is_initialized;
   TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
   if (is_initialized) {
@@ -312,8 +332,7 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
     for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
       std::vector<xla::XlaOp> dynamic_dims;
       const xla::Shape& shape = list_shape.tuple_shapes(i);
-      // Leading dim is a static dimension.
-      dynamic_dims.push_back(xla::ConstantR0<int32>(b, leading_dim));
+      dynamic_dims.push_back(leading_dim_dynamic_size);
       xla::XlaOp sub_element;
       if (element_is_tensor_list) {
         sub_element = xla::GetTupleElement(element, i);
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
index ef3c8badf71..549ccd5aece 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h
@@ -60,17 +60,22 @@ Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
 
 // Returns an uninitialized TensorList.
 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
-                                        int64 leading_dimension);
+                                        int64 leading_dimension,
+                                        bool leading_size_is_dynamic,
+                                        xla::XlaOp leading_dim_size);
 
-// Returns leading dimension for the TensorList.
-// Input can be initialized or uninitialized TensorList.
-// Non-nested and nested TensorLists are both supported.
-Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim);
+// Returns leading dimension for the TensorList as well as a dynamic op
+// representing the dynamic size. Input can be initialized or uninitialized
+// TensorList. Non-nested and nested TensorLists are both supported.
+Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
+                                  bool* leading_dim_is_dynamic,
+                                  xla::XlaOp* leading_dim_dynamic_size);
 
 // Returns TensorList shape for the element shape.
 // Element shape must be a normal tensor shape.
 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                           int64 leading_dim,
+                                          bool leading_dim_is_dynamic,
                                           xla::Shape* tensor_list_shape);
 
 // Returns a TensorList filled by zeros with the given shape.
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index fe7a5898011..a94411f1b30 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -513,10 +513,26 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
         // Prepare dynamic dimensions for element shapes.
         std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
         for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
-          // Set dynamic dimension size to 0 for initilization value.
           std::vector<xla::XlaOp> dynamic_dims;
+
           const xla::Shape& shape = list_shape.tuple_shapes(i);
-          for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
+
+          // We already have the dynamic size of leading dimension outside of
+          // the while loop without initializing the TensorList inside the while
+          // loop.
+          if (shape.is_dynamic_dimension(0)) {
+            xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
+            dynamic_dims.push_back(leading_dim_size);
+          } else {
+            int32 dim_size = shape.dimensions(0);
+            dynamic_dims.push_back(
+                xla::ConstantR0<int32>(ctx->builder(), dim_size));
+          }
+
+          // Set dynamic dimension size to 0 for element value. Inside the while
+          // loop, TensorlistSetItem will properly set the element shape's
+          // dynamic diemnsion.
+          for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
             int32 dim_size = shape.dimensions(dim);
             if (shape.is_dynamic_dimension(dim)) {
               dim_size = 0;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 07537546d52..c2d1906e47a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -259,6 +259,32 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal,
   return Status::OK();
 }
 
+Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
+  xla::Literal literal;
+  XlaExpression e = InputExpression(index);
+  auto* client = compiler() ? compiler()->client() : nullptr;
+  xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
+  if (!dynamism_or_status.ok()) {
+    Status status = dynamism_or_status.status();
+    errors::AppendToMessage(&status, "while evaluating input dynamism", index,
+                            " of ", context_->op_kernel().type_string());
+    return status;
+  }
+  Tensor dynamism = dynamism_or_status.ValueOrDie();
+
+  Tensor temp(dynamism.dtype());
+  TensorShape tensor_shape({});
+  if (!temp.CopyFrom(dynamism, tensor_shape)) {
+    return errors::InvalidArgument(
+        context_->op_kernel().name(), " input ", index, " has shape ",
+        dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
+  }
+
+  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+  *out = literal.Get<bool>({});
+  return Status::OK();
+}
+
 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
     int index, std::vector<bool>* out) {
   xla::Literal literal;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 75c3e60171a..1ed343ba20f 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -119,7 +119,7 @@ class XlaOpKernelContext {
   // Evaluates input and returns their dynamism vector in a vector of
   // predicates.
   Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
-
+  Status ResolveInputDynamismIntoPred(int index, bool* out);
   // Helper methods for constant inputs.
 
   // Evaluates input `index` and stores it in `*constant_literal`. If the