diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
index d58b1cd74c4..cc7ce6b39dc 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
@@ -96,6 +96,9 @@ bool TensorFlowRefType::classof(Type type) {
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
       >();
 }
+bool TensorFlowTypeWithSubtype::classof(Type type) {
+  return type.isa<ResourceType, VariantType>();
+}
 
 TensorFlowType TensorFlowRefType::get(Type type) {
   MLIRContext* ctx = type.getContext();
@@ -178,10 +181,6 @@ Type TensorFlowRefType::RemoveRef() {
   llvm_unreachable("unexpected tensorflow ref type kind");
 }
 
-bool TensorFlowTypeWithSubtype::classof(Type type) {
-  return type.isa<ResourceType, VariantType>();
-}
-
 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
   MLIRContext* ctx = getContext();
   if (isa<VariantType>()) return VariantType::get(ctx);
@@ -189,18 +188,6 @@ Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
   llvm_unreachable("unexpected tensorflow type with subtypes kind");
 }
 
-TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
-    ArrayRef<TensorType> new_subtypes) {
-  MLIRContext* ctx = getContext();
-  if (isa<VariantType>())
-    return VariantType::get(new_subtypes, ctx)
-        .cast<TensorFlowTypeWithSubtype>();
-  if (isa<ResourceType>())
-    return ResourceType::get(new_subtypes, ctx)
-        .cast<TensorFlowTypeWithSubtype>();
-  llvm_unreachable("unexpected tensorflow type with subtypes kind");
-}
-
 ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
   if (auto variant_type = dyn_cast<VariantType>())
     return variant_type.getSubtypes();
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index 0ba3201b95d..57ff9dce272 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -246,9 +246,6 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
   // Converts a TypeWithSubtype type to the same type but without its subtypes.
   Type RemoveSubtypes();
 
-  // Clone the current Type with new subtypes.
-  TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
-
   // Returns the subtypes.
   ArrayRef<TensorType> GetSubtypes();
 };
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 9b43e3dcd17..9a2f660c74a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -94,120 +94,20 @@ bool CanBeRefined(Type type) {
   return !shape_type.hasStaticShape();
 }
 
-// Compute a refined type between two types `lhs` and `rhs`, the result type
-// is always more refined (i.e. has more static information) than `lhs`
-// This method will actually merge the information contained in the
-// types, it is capable of refining:
-//   tensor<!tf.variant<tensor<?x8xf32>>>
-// and:
-//   tensor<!tf.variant<tensor<10x?xf32>>>
-// into:
-//   tensor<!tf.variant<tensor<10x8xf32>>>
-//
-// In case of inconsistencies (rank disagreement for example), it returns `lhs`.
-Type TypeMeet(Type lhs, Type rhs) {
-  DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
-  if (lhs == rhs) return lhs;
-
-  auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
-  if (!rhs_shape_type) return lhs;
-  auto lhs_shape_type = lhs.cast<ShapedType>();
-  if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
-      lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
-    DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
-    return lhs;
-  }
-
-  SmallVector<int64_t> shape;
-  bool refined_shape = false;
-  // Build the shape of the refined type, if lhs is unranked it
-  // will be directly the shape of the refined type, otherwise we merged by
-  // taking the most specialized. This combines `10x?x?` and `?x?x8` into
-  // `10x?x8`.
-  if (!lhs_shape_type.hasRank()) {
-    if (rhs_shape_type.hasRank()) {
-      shape.append(rhs_shape_type.getShape().begin(),
-                   rhs_shape_type.getShape().end());
-      refined_shape = true;
-    }
-  } else if (rhs_shape_type.hasRank()) {
-    for (auto shape_elts : llvm::enumerate(
-             llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
-      if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
-          !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
-        shape.push_back(std::get<1>(shape_elts.value()));
-        refined_shape = true;
-        DCOMMENT("-> refining shape element #" << shape_elts.index());
-      } else {
-        DCOMMENT("-> not refining shape element #" << shape_elts.index());
-        shape.push_back(std::get<0>(shape_elts.value()));
-      }
-    }
-  }
-
-  // Some tensor have an element type wrapping a subtensor, like resource and
-  // variants. In this case we may recurse on the wrapped subtype.
-  // `element_type` will contain the refined inferred element type for the
-  // returned type.
-  auto lhs_element_type = lhs_shape_type.getElementType();
-  auto rhs_element_type_with_subtype =
-      rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
-  // Look for resource or variant element type and ensure we refine the subtype.
-  // We only support a single subtype at the moment, we won't handle something
-  // like:
-  //   tensor<!tf.variant<tensor<10xf32>, tensor<8xf32>>
-  if (rhs_element_type_with_subtype &&
-      rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
-    auto lhs_element_type_with_subtype =
-        lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
-    TensorType subtype;
-    if (!lhs_element_type_with_subtype) {
-      DCOMMENT(
-          "Unexpected inferred `TensorFlowTypeWithSubtype` when original "
-          "result isn't");
-    } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
-      DCOMMENT(
-          "Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
-    } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
-      subtype = rhs_element_type_with_subtype.GetSubtypes().front();
-    } else {
-      // Recurse on the subtypes in the variant/resource. Basically if the input
-      // were:
-      //   tensor<!tf.variant<tensor<?x8xf32>>>
-      // and:
-      //   tensor<!tf.variant<tensor<10x8xf32>>>
-      // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
-      auto refined_subtype =
-          TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
-                   rhs_element_type_with_subtype.GetSubtypes().front())
-              .cast<TensorType>();
-      if (refined_subtype !=
-          lhs_element_type_with_subtype.GetSubtypes().front())
-        subtype = refined_subtype;
-    }
-    // If we managed to refine the subtype, recreate the element type itself
-    // (i.e. the tf.variant or tf.resource).
-    if (subtype) {
-      lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
-    }
-  }
-  if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
-    Type new_type;
-    if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
-      new_type = UnrankedTensorType::get(lhs_element_type);
-    else
-      new_type = lhs_shape_type.clone(shape, lhs_element_type);
-    DCOMMENT("Refined to: " << new_type);
-    return new_type;
-  }
-  DCOMMENT("No refinement " << lhs);
-  return lhs;
-}
-
 // Returns whether `original_type` type can be refined with
 // `potential_refined_type` type.
 bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
-  return original_type != TypeMeet(original_type, potential_refined_type);
+  if (original_type == potential_refined_type || !CanBeRefined(original_type))
+    return false;
+
+  auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
+  if (!shape_type) return false;
+  if (shape_type.hasRank()) return true;
+
+  auto element_type_with_subtype =
+      shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
+  return element_type_with_subtype &&
+         !element_type_with_subtype.GetSubtypes().empty();
 }
 
 // Returns if the shape inference pass supports an op outside the TF dialect.
@@ -799,7 +699,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
   FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
   if (!func) return false;
 
-  DCOMMENT("Infer shape for call " << func.getName());
+  LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
   Operation* op = call_op.getOperation();
   bool changed = false;
   // Map each of the results of the call to the returned type of the
@@ -808,7 +708,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
               changed;
   }
-  DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n");
+  LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
 
   return changed;
 }
@@ -885,7 +785,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
     DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
     return false;
   }
-  DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
+  DCOMMENT("InferShapeForListInitOps " << op << " could be inferred "
                                        << element_type);
   if (!element_type || !element_type.hasStaticShape()) return false;
   auto variant_type = VariantType::get(element_type, op->getContext());