diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index a462a7f4a1f..7c7f6f306cf 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -2903,6 +2903,22 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -
   return %0 : tensor<256x30x30x16xf32>
 }
 
+// CHECK-LABEL: depthwiseconv_simple
+func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
+  // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
+  // CHECK: "xla_hlo.conv"(%arg0, %[[RESHAPED_FILTER]])
+  // CHECK: feature_group_count = 3
+  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
+    data_format = "NHWC",
+    device = "",
+    dilations = [1, 1, 1, 1],
+    explicit_paddings = [],
+    padding = "VALID",
+    strides = [1, 1, 1, 1]
+  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
+  return %0 : tensor<2x3x4x9xf32>
+}
+
 // CHECK-LABEL: conv_valid_padding
 func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
   // CHECK: "xla_hlo.conv"(%arg0, %arg1)
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index f8ec40aa42d..aa6ac85b4af 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -777,7 +777,7 @@ NamedAttribute GetConvDimensionNumbersAttr(
 // the paddings attribute anyway requires multiple source op attributes and
 // result op attributes. Defining it as declarative rewrite rule will introduce
 // some duplication in the C++ helper methods.
-template <typename OpT, int num_spatial_dims>
+template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
 class ConvertConv : public OpRewritePattern<OpT> {
  public:
   using OpRewritePattern<OpT>::OpRewritePattern;
@@ -868,7 +868,10 @@ class ConvertConv : public OpRewritePattern<OpT> {
     int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
     // TensorFlow convolution op verifies that the number of input channels is
     // divisible by the number of filter channels.
-    int64_t feature_group_count = input_channels / filter_channels;
+    // For depthwise convolution the feature_group_count argument would be set
+    // to the input feature dimension.
+    int64_t feature_group_count =
+        depthwise_conv ? input_channels : input_channels / filter_channels;
     auto feature_group_count_attr = rewriter.getNamedAttr(
         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
 
@@ -881,6 +884,22 @@ class ConvertConv : public OpRewritePattern<OpT> {
         "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
 
     SmallVector<Value, 2> operands(op.getOperands());
+    // Reshape the filter to {spatial_dims...., 1,in_channels *
+    // channel_multiplier}
+    if (depthwise_conv) {
+      auto filter_shape = filter_ty.getShape();
+      llvm::SmallVector<int64_t, 4> new_shape(filter_shape.size());
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_shape[i] = filter_shape[i];
+      }
+      new_shape[num_spatial_dims] = 1;
+      new_shape[num_spatial_dims + 1] =
+          filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1];
+      operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
+          op.getLoc(),
+          RankedTensorType::get(new_shape, filter_ty.getElementType()),
+          operands[1]);
+    }
     NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
                               dimension_numbers_attr, feature_group_count_attr,
                               batch_group_count_attr, paddings_attr};
@@ -891,7 +910,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
 };
 
 using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;
-
+using ConvertDepthConv2D =
+    ConvertConv<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
+                /*depthwise_conv=*/true>;
 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
 // division can result in strange behavior.
 //
@@ -3767,15 +3788,15 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   patterns.insert<
       ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
       ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D,
-      ConvertConv2DBackpropFilterOp, ConvertConv2DBackpropInputOp,
-      ConvertCumsumOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp,
-      ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
-      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
-      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp,
-      ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
-      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
-      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertDepthConv2D, ConvertConv2DBackpropFilterOp,
+      ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertEinsumOp,
+      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
+      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
+      ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, ConvertMaxOp,
+      ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp,
+      ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,