From 0c0555d94ba1f641bb722a918d3acf6fd4742fdc Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Wed, 6 May 2020 23:23:12 -0700
Subject: [PATCH] Split merged TFLiteConverter implementations into frozen
 graph converter and saved model converter

PiperOrigin-RevId: 310301390
Change-Id: I021c1fa678d6367226e1a19e646bb6d0ff9769e3
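
For context, the public entry points are unchanged: `from_saved_model` now
returns the SavedModel-based converter when that conversion path applies, and
the frozen-graph-based converter otherwise. A minimal usage sketch
(hypothetical model path; both classes expose the same `convert()` method):

    import tensorflow as tf

    # May be a TFLiteSavedModelConverterV2 or a TFLiteFrozenGraphConverterV2.
    converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")
    tflite_model = converter.convert()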
---
 tensorflow/lite/python/lite.py                | 1394 ++++++++++-------
 tensorflow/lite/python/lite_test.py           |   37 +
 tensorflow/lite/python/lite_v2_test.py        |   21 +
 .../tensorflow.lite.-t-f-lite-converter.pbtxt |    4 +-
 .../tensorflow.lite.-t-f-lite-converter.pbtxt |    4 +-
 5 files changed, 910 insertions(+), 550 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 15704a6da76..b2d58ec8746 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -296,9 +296,9 @@ class TFLiteConverterBase(object):
     # The 'GraphDebugInfo'  contains the stack traces of all the original nodes
     # in the `GraphDef` to the converter.
     self._debug_info = None
-    self._saved_model_dir = None
+    self.saved_model_dir = None
     self._saved_model_tags = None
-    self._saved_model_version = None
+    self._saved_model_version = 0
     self._saved_model_exported_names = []
     self._experimental_sparsify_model = False
 
@@ -366,9 +366,9 @@ class TFLiteConverterBase(object):
         "enable_mlir_converter": self.experimental_new_converter,
     }
 
-    if self._saved_model_dir:
+    if self.saved_model_dir:
       args.update({
-          "saved_model_dir": self._saved_model_dir,
+          "saved_model_dir": self.saved_model_dir,
           "saved_model_version": self._saved_model_version,
           "saved_model_tags": self._saved_model_tags,
           "saved_model_exported_names": self._saved_model_exported_names,
@@ -387,19 +387,19 @@ class TFLiteConverterBase(object):
   def _parse_saved_model_args(self):
     """Parses SavedModel arguments from the given Keras/RNN SavedModel."""
     if not self.experimental_new_converter:
-      self._saved_model_dir = None
+      self.saved_model_dir = None
       return
-    if self._saved_model_dir:
+    if self.saved_model_dir:
       try:
         saved_model_proto, _ = (
-            _parse_saved_model_with_debug_info(self._saved_model_dir))
+            _parse_saved_model_with_debug_info(self.saved_model_dir))
       except OSError:
         # If it fails to read the given saved model, it will fall back to the
         # frozen graph def path.
-        self._saved_model_dir = None
+        self.saved_model_dir = None
         return
       if not self._contains_function_with_implements_attr(saved_model_proto):
-        self._saved_model_dir = None
+        self.saved_model_dir = None
       else:
         self._saved_model_exported_names = []
         self._saved_model_version = saved_model_proto.saved_model_schema_version
@@ -409,179 +409,18 @@ class TFLiteConverterBase(object):
                   self._saved_model_version))
 
 
-@_tf_export("lite.TFLiteConverter", v1=[])
-class TFLiteConverterV2(TFLiteConverterBase):
-  """Converts a TensorFlow model into TensorFlow Lite model.
+class TFLiteConverterBaseV2(TFLiteConverterBase):
+  """Converter subclass to share functionality between V2 converters."""
 
-  Attributes:
-    allow_custom_ops: Boolean indicating whether to allow custom operations.
-      When false any unknown operation is an error. When true, custom ops are
-      created for any op that is unknown. The developer will need to provide
-      these to the TensorFlow Lite runtime with a custom resolver.
-      (default False)
-    target_spec: Experimental flag, subject to change. Specification of target
-      device.
-    optimizations: Experimental flag, subject to change. A list of optimizations
-      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
-    representative_dataset: A representative dataset that can be used to
-      generate input and output samples for the model. The converter can use the
-      dataset to evaluate different optimizations. Note that this is an optional
-      attribute but it is necessary if INT8 is the only support builtin ops in
-      target ops.
-    experimental_new_converter: Experimental flag, subject to change.
-      Enables MLIR-based conversion instead of TOCO conversion.
-  Example usage:
-
-    ```python
-    # Converting a SavedModel to a TensorFlow Lite model.
-    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
-    tflite_model = converter.convert()
-
-    # Converting a tf.Keras model to a TensorFlow Lite model.
-    converter = lite.TFLiteConverter.from_keras_model(model)
-    tflite_model = converter.convert()
-
-    # Converting ConcreteFunctions to a TensorFlow Lite model.
-    converter = lite.TFLiteConverter.from_concrete_functions([func])
-    tflite_model = converter.convert()
-    ```
-  """
-
-  def __init__(self,
-               funcs,
-               trackable_obj=None,
-               saved_model_dir=None,
-               saved_model_tags=None):
-    """Constructor for TFLiteConverter.
-
-    Args:
-      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
-        duplicate elements.
-      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
-        reference to this object needs to be maintained so that Variables do not
-        get garbage collected since functions have a weak reference to
-        Variables. This is only required when the tf.AutoTrackable object is not
-        maintained by the user (e.g. `from_saved_model`).
-      saved_model_dir: Directory of the SavedModel. This argument can be null
-        when it creates via the from_keras_model and from_concrete_function
-        methods.
-      saved_model_tags: Set of tags identifying the MetaGraphDef within the
-        SavedModel to analyze. All tags in the tag set must be present. (default
-        set(SERVING)).  This argument will be available when the saved model dir
-        argument is set.
-    """
-    super(TFLiteConverterV2, self).__init__()
-    self._funcs = funcs
-    self._trackable_obj = trackable_obj
-    self._saved_model_dir = saved_model_dir
-    self._saved_model_tags = saved_model_tags
-
-  @classmethod
-  def from_concrete_functions(cls, funcs):
-    """Creates a TFLiteConverter object from ConcreteFunctions.
-
-    Args:
-      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
-        duplicate elements. Currently converter can only convert a single
-        ConcreteFunction. Converting multiple functions is under development.
-
-    Returns:
-      TFLiteConverter object.
-
-    Raises:
-      Invalid input type.
-    """
-    for func in funcs:
-      if not isinstance(func, _function.ConcreteFunction):
-        message = "This function takes in a list of ConcreteFunction."
-        if isinstance(func, _def_function.Function):
-          message += (" To get the ConcreteFunction from a Function,"
-                      " call get_concrete_function.")
-        raise ValueError(message)
-    return cls(funcs)
-
-  @classmethod
-  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
-    """Creates a TFLiteConverter object from a SavedModel directory.
-
-    Args:
-      saved_model_dir: SavedModel directory to convert.
-      signature_keys: List of keys identifying SignatureDef containing inputs
-        and outputs. Elements should not be duplicated. By default the
-        `signatures` attribute of the MetaGraphdef is used. (default
-        saved_model.signatures)
-      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
-        analyze. All tags in the tag set must be present. (default set(SERVING))
-
-    Returns:
-      TFLiteConverter object.
-
-    Raises:
-      Invalid signature keys.
-    """
-    # When run without eager enabled, this will return the legacy
-    # TFLiteConverter.
-    if not context.executing_eagerly():
-      signature_key = None
-      if signature_keys:
-        if len(signature_keys) != 1:
-          raise ValueError("Only support a single signature key.")
-        else:
-          signature_key = signature_keys[0]
-      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
-                      "because eager is disabled. Consider enabling eager.")
-      return TFLiteConverter.from_saved_model(saved_model_dir,
-                                              signature_key=signature_key,
-                                              tag_set=tags)
-
-    # Ensures any graphs created in Eager mode are able to run. This is required
-    # in order to create a tf.estimator.Exporter that exports a TFLite model.
-    if tags is None:
-      tags = set([_tag_constants.SERVING])
-
-    with context.eager_mode():
-      saved_model = _load(saved_model_dir, tags)
-    if not signature_keys:
-      signature_keys = saved_model.signatures
-
-    funcs = []
-    for key in signature_keys:
-      if key not in saved_model.signatures:
-        raise ValueError("Invalid signature key '{}' found. Valid keys are "
-                         "'{}'.".format(key, ",".join(saved_model.signatures)))
-      funcs.append(saved_model.signatures[key])
-
-    return cls(funcs, saved_model, saved_model_dir, tags)
-
-  @classmethod
-  def from_keras_model(cls, model):
-    """Creates a TFLiteConverter object from a Keras model.
-
-    Args:
-      model: tf.Keras.Model
-
-    Returns:
-      TFLiteConverter object.
-    """
-    input_signature = None
-    # If the model's call is not a `tf.function`, then we need to first get its
-    # input signature from `model_input_signature` method. We can't directly
-    # call `trace_model_call` because otherwise the batch dimension is set
-    # to None.
-    # Once we have better support for dynamic shapes, we can remove this.
-    if not isinstance(model.call, _def_function.Function):
-      # Pass `keep_original_batch_size=True` will ensure that we get an input
-      # signature including the batch dimension specified by the user.
-      input_signature = _saving_utils.model_input_signature(
-          model, keep_original_batch_size=True)
-
-    func = _saving_utils.trace_model_call(model, input_signature)
-    concrete_func = func.get_concrete_function()
-    return cls([concrete_func])
-
-  def convert(self):
+  def _convert(self, graph_def, input_tensors, output_tensors):
     """Converts a TensorFlow GraphDef based on instance variables.
 
+    Args:
+      graph_def: Frozen TensorFlow GraphDef.
+      input_tensors: List of input tensors. Type and shape are computed using
+        `foo.shape` and `foo.dtype`.
+      output_tensors: List of output tensors (only their names are used).
+
     Returns:
       The converted data in serialized format.
 
@@ -592,62 +431,6 @@ class TFLiteConverterV2(TFLiteConverterBase):
         Input shape is not specified.
         Invalid quantization parameters.
     """
-    # TODO(b/130297984): Add support for converting multiple function.
-
-    if len(self._funcs) == 0:
-      raise ValueError("No ConcreteFunction is specified.")
-
-    if len(self._funcs) > 1:
-      raise ValueError("This converter can only convert a single "
-                       "ConcreteFunction. Converting multiple functions is "
-                       "under development.")
-
-    # Parses SavedModel argument.
-    self._parse_saved_model_args()
-
-    # graph_def is used here to preserve the node bug information
-    if self._saved_model_dir:
-      graph = _ops.Graph()
-      saved_model = _loader_impl.SavedModelLoader(self._saved_model_dir)
-      saved_model.load_graph(graph, tags=self._saved_model_tags)
-      meta_graph = saved_model.get_meta_graph_def_from_tags(
-          self._saved_model_tags)
-      signature_def = meta_graph.signature_def[
-          _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-      input_tensors = [
-          graph.get_tensor_by_name(signature_def.inputs[key].name)
-          for key in signature_def.inputs
-      ]
-      output_tensors = [
-          graph.get_tensor_by_name(signature_def.outputs[key].name)
-          for key in signature_def.outputs
-      ]
-      self._graph_def = graph_def = meta_graph.graph_def
-    else:
-      frozen_func, graph_def = (
-          _convert_to_constants.convert_variables_to_constants_v2_as_graph(
-              self._funcs[0], lower_control_flow=False))
-      self._graph_def = graph_def
-
-      input_tensors = [
-          tensor for tensor in frozen_func.inputs
-          if tensor.dtype != _dtypes.resource
-      ]
-      output_tensors = frozen_func.outputs
-
-      # Run a Grappler pass.
-      grappler_config = self._grappler_config()
-      # Skip running grappler when there are no optimizers to run. If not,
-      # grappler will run with the default optimizer set and it will lead to
-      # causing an unexpected behavior.
-      if grappler_config.graph_options.rewrite_options.optimizers:
-        graph_def = _run_graph_optimizations(
-            graph_def,
-            input_tensors,
-            output_tensors,
-            config=grappler_config,
-            graph=frozen_func.graph)
-
     quant_mode = QuantizationMode(self.optimizations, self.target_spec,
                                   self.representative_dataset, graph_def)
 
@@ -724,8 +507,820 @@ class TFLiteConverterV2(TFLiteConverterBase):
     return result
 
 
+class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2):
+  """Converts the given SavedModel into TensorFlow Lite model.
+
+  Attributes:
+    saved_model_dir: Directory of the SavedModel.
+  """
+
+  def __init__(self,
+               saved_model_dir,
+               saved_model_tags=None,
+               saved_model_exported_names=None,
+               trackable_obj=None):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      saved_model_dir: Directory of the SavedModel.
+      saved_model_tags: Set of tags identifying the MetaGraphDef within the
+        SavedModel to analyze. All tags in the tag set must be present. (default
+        set(SERVING)).
+      saved_model_exported_names: Names to be exported when the SavedModel
+        import path is enabled. (default: export all)
+      trackable_obj: tf.AutoTrackable object associated with the SavedModel. A
+        reference to this object needs to be maintained so that Variables do not
+        get garbage collected since functions have a weak reference to
+        Variables. This is only required when the tf.AutoTrackable object is not
+        maintained by the user (e.g. `from_saved_model`).
+    """
+    super(TFLiteSavedModelConverterV2, self).__init__()
+    self.saved_model_dir = saved_model_dir
+    self._saved_model_tags = saved_model_tags
+    self._saved_model_exported_names = saved_model_exported_names
+    self._trackable_obj = trackable_obj
+    self._parse_saved_model_args()
+
+  def convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format.
+
+    Raises:
+      ValueError:
+        Input shape is not specified.
+        Invalid quantization parameters.
+    """
+    graph = _ops.Graph()
+    saved_model = _loader_impl.SavedModelLoader(self.saved_model_dir)
+    saved_model.load_graph(graph, tags=self._saved_model_tags)
+    meta_graph = saved_model.get_meta_graph_def_from_tags(
+        self._saved_model_tags)
+    signature_def = meta_graph.signature_def[
+        _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
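+    # Resolve the tensor names recorded in the serving signature to the
+    # corresponding tensors in the loaded graph.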
+    input_tensors = [
+        graph.get_tensor_by_name(signature_def.inputs[key].name)
+        for key in signature_def.inputs
+    ]
+    output_tensors = [
+        graph.get_tensor_by_name(signature_def.outputs[key].name)
+        for key in signature_def.outputs
+    ]
+    return self._convert(meta_graph.graph_def, input_tensors, output_tensors)
+
+
+class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
+  """Converts the given frozen graph into TensorFlow Lite model."""
+
+  def __init__(self, funcs, trackable_obj=None):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
+        duplicate elements.
+      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
+        reference to this object needs to be maintained so that Variables do not
+        get garbage collected since functions have a weak reference to
+        Variables. This is only required when the tf.AutoTrackable object is not
+        maintained by the user (e.g. `from_saved_model`).
+    """
+    super(TFLiteFrozenGraphConverterV2, self).__init__()
+    self._funcs = funcs
+    self._trackable_obj = trackable_obj
+
+  def convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format.
+
+    Raises:
+      ValueError:
+        No concrete function is specified.
+        Multiple concrete functions are specified.
+        Input shape is not specified.
+        Invalid quantization parameters.
+    """
+    # TODO(b/130297984): Add support for converting multiple functions.
+
+    if len(self._funcs) == 0:
+      raise ValueError("No ConcreteFunction is specified.")
+
+    if len(self._funcs) > 1:
+      raise ValueError("This converter can only convert a single "
+                       "ConcreteFunction. Converting multiple functions is "
+                       "under development.")
+
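+    # Freeze variables into constants so that conversion operates on a
+    # self-contained GraphDef.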
+    frozen_func, graph_def = (
+        _convert_to_constants.convert_variables_to_constants_v2_as_graph(
+            self._funcs[0], lower_control_flow=False))
+
+    input_tensors = [
+        tensor for tensor in frozen_func.inputs
+        if tensor.dtype != _dtypes.resource
+    ]
+    output_tensors = frozen_func.outputs
+
+    # Run a Grappler pass.
+    grappler_config = self._grappler_config()
+    # Skip running Grappler when there are no optimizers to run. Otherwise,
+    # Grappler would run with the default optimizer set, which leads to
+    # unexpected behavior.
+    if grappler_config.graph_options.rewrite_options.optimizers:
+      graph_def = _run_graph_optimizations(
+          graph_def,
+          input_tensors,
+          output_tensors,
+          config=grappler_config,
+          graph=frozen_func.graph)
+
+    return self._convert(graph_def, input_tensors, output_tensors)
+
+
+@_tf_export("lite.TFLiteConverter", v1=[])
+class TFLiteConverterV2(TFLiteFrozenGraphConverterV2):
+  """Converts a TensorFlow model into TensorFlow Lite model.
+
+  Attributes:
+    allow_custom_ops: Boolean indicating whether to allow custom operations.
+      When false any unknown operation is an error. When true, custom ops are
+      created for any op that is unknown. The developer will need to provide
+      these to the TensorFlow Lite runtime with a custom resolver.
+      (default False)
+    target_spec: Experimental flag, subject to change. Specification of target
+      device.
+    optimizations: Experimental flag, subject to change. A list of optimizations
+      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
+    representative_dataset: A representative dataset that can be used to
+      generate input and output samples for the model. The converter can use the
+      dataset to evaluate different optimizations. Note that this is an
+      optional attribute, but it is required if INT8 is the only supported
+      builtin op in the target ops.
+    experimental_new_converter: Experimental flag, subject to change.
+      Enables MLIR-based conversion instead of TOCO conversion.
+  Example usage:
+
+    ```python
+    # Converting a SavedModel to a TensorFlow Lite model.
+    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
+    tflite_model = converter.convert()
+
+    # Converting a tf.Keras model to a TensorFlow Lite model.
+    converter = lite.TFLiteConverter.from_keras_model(model)
+    tflite_model = converter.convert()
+
+    # Converting ConcreteFunctions to a TensorFlow Lite model.
+    converter = lite.TFLiteConverter.from_concrete_functions([func])
+    tflite_model = converter.convert()
+    ```
+  """
+
+  # pylint: disable=useless-super-delegation
+  def __init__(self, funcs, trackable_obj=None):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
+        duplicate elements.
+      trackable_obj: tf.AutoTrackable object associated with `funcs`. A
+        reference to this object needs to be maintained so that Variables do not
+        get garbage collected since functions have a weak reference to
+        Variables. This is only required when the tf.AutoTrackable object is not
+        maintained by the user (e.g. `from_saved_model`).
+    """
+    super(TFLiteConverterV2, self).__init__(funcs, trackable_obj)
+
+  @classmethod
+  def from_concrete_functions(cls, funcs):
+    """Creates a TFLiteConverter object from ConcreteFunctions.
+
+    Args:
+      funcs: List of TensorFlow ConcreteFunctions. The list should not contain
+        duplicate elements. Currently, the converter can only convert a single
+        ConcreteFunction. Converting multiple functions is under development.
+
+    Returns:
+      TFLiteConverter object.
+
+    Raises:
+      Invalid input type.
+    """
+    for func in funcs:
+      if not isinstance(func, _function.ConcreteFunction):
+        message = "This function takes in a list of ConcreteFunction."
+        if isinstance(func, _def_function.Function):
+          message += (" To get the ConcreteFunction from a Function,"
+                      " call get_concrete_function.")
+        raise ValueError(message)
+    return cls(funcs)
+
+  @classmethod
+  def from_saved_model(cls, saved_model_dir, signature_keys=None, tags=None):
+    """Creates a TFLiteConverter object from a SavedModel directory.
+
+    Args:
+      saved_model_dir: SavedModel directory to convert.
+      signature_keys: List of keys identifying SignatureDef containing inputs
+        and outputs. Elements should not be duplicated. By default the
+        `signatures` attribute of the MetaGraphDef is used. (default
+        saved_model.signatures)
+      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
+        analyze. All tags in the tag set must be present. (default set(SERVING))
+
+    Returns:
+      TFLiteConverter object.
+
+    Raises:
+      Invalid signature keys.
+    """
+    # When run without eager enabled, this will return the legacy
+    # TFLiteConverter.
+    if not context.executing_eagerly():
+      signature_key = None
+      if signature_keys:
+        if len(signature_keys) != 1:
+          raise ValueError("Only support a single signature key.")
+        else:
+          signature_key = signature_keys[0]
+      logging.warning("Invoking the TF1 implementation of TFLiteConverter "
+                      "because eager is disabled. Consider enabling eager.")
+      return TFLiteConverter.from_saved_model(saved_model_dir,
+                                              signature_key=signature_key,
+                                              tag_set=tags)
+
+    # Ensures any graphs created in Eager mode are able to run. This is required
+    # in order to create a tf.estimator.Exporter that exports a TFLite model.
+    if tags is None:
+      tags = set([_tag_constants.SERVING])
+
+    with context.eager_mode():
+      saved_model = _load(saved_model_dir, tags)
+    if not signature_keys:
+      signature_keys = saved_model.signatures
+
+    funcs = []
+    for key in signature_keys:
+      if key not in saved_model.signatures:
+        raise ValueError("Invalid signature key '{}' found. Valid keys are "
+                         "'{}'.".format(key, ",".join(saved_model.signatures)))
+      funcs.append(saved_model.signatures[key])
+
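+    # Prefer the direct SavedModel conversion path when it applies; otherwise
+    # fall back to converting the signature functions as a frozen graph.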
+    saved_model_converter = TFLiteSavedModelConverterV2(saved_model_dir, tags,
+                                                        signature_keys,
+                                                        saved_model)
+    if saved_model_converter.saved_model_dir:
+      return saved_model_converter
+
+    return cls(funcs, saved_model)
+
+  @classmethod
+  def from_keras_model(cls, model):
+    """Creates a TFLiteConverter object from a Keras model.
+
+    Args:
+      model: tf.Keras.Model
+
+    Returns:
+      TFLiteConverter object.
+    """
+    input_signature = None
+    # If the model's call is not a `tf.function`, then we need to first get its
+    # input signature from `model_input_signature` method. We can't directly
+    # call `trace_model_call` because otherwise the batch dimension is set
+    # to None.
+    # Once we have better support for dynamic shapes, we can remove this.
+    if not isinstance(model.call, _def_function.Function):
+      # Passing `keep_original_batch_size=True` ensures that we get an input
+      # signature including the batch dimension specified by the user.
+      input_signature = _saving_utils.model_input_signature(
+          model, keep_original_batch_size=True)
+
+    func = _saving_utils.trace_model_call(model, input_signature)
+    concrete_func = func.get_concrete_function()
+    return cls([concrete_func])
+
+  # pylint: disable=useless-super-delegation
+  def convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format.
+
+    Raises:
+      ValueError:
+        No concrete function is specified.
+        Multiple concrete functions are specified.
+        Input shape is not specified.
+        Invalid quantization parameters.
+    """
+    return super(TFLiteConverterV2, self).convert()
+
+
+class TFLiteConverterBaseV1(TFLiteConverterBase):
+  """Converter subclass to share functionality between V1 converters.
+
+  Attributes:
+    inference_type: Target data type of real-number arrays in the output file.
+      Must be `{tf.float32, tf.uint8}`. If `optimizations` are provided, this
+      parameter is ignored. (default tf.float32)
+    inference_input_type: Target data type of real-number input arrays. Allows
+      for a different type for input arrays. If an integer type is provided and
+      `optimizations` are not used, `quantized_inputs_stats` must be provided.
+      If `inference_type` is tf.uint8, signaling conversion to a fully quantized
+      model from a quantization-aware trained input model, then
+      `inference_input_type` defaults to tf.uint8. In all other cases,
+      `inference_input_type` defaults to tf.float32. Must be `{tf.float32,
+      tf.uint8, tf.int8}`
+    inference_output_type: Target data type of real-number output arrays. Allows
+      for a different type for output arrays. If `inference_type` is tf.uint8,
+      signaling conversion to a fully quantized model from a quantization-aware
+      trained output model, then `inference_output_type` defaults to tf.uint8.
+      In all other cases, `inference_output_type` must be tf.float32, an error
+      will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}`
+    output_format: Output file format. Currently must be `{TFLITE,
+      GRAPHVIZ_DOT}`. (default TFLITE)
+    quantized_input_stats: Dict of strings representing input tensor names
+      mapped to tuples of floats representing the mean and standard deviation
+      of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
+      `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
+      (quantized_input_value - mean_value) / std_dev_value. (default {})
+    default_ranges_stats: Tuple of integers representing (min, max) range values
+      for all arrays without a specified range. Intended for experimenting with
+      quantization via "dummy quantization". (default None)
+    drop_control_dependency: Boolean indicating whether to drop control
+      dependencies silently. This is due to TFLite not supporting control
+      dependencies. (default True)
+    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+      nodes in unexpected locations. Used when the location of the FakeQuant
+      nodes is preventing graph transformations necessary to convert the graph.
+      Results in a graph that differs from the quantized training graph,
+      potentially causing differing arithmetic behavior. (default False)
+    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+      inputs and outputs of the concat operator for quantized models. Changes
+      the ranges of concat operator overlap when true. (default False)
+    allow_custom_ops: Boolean indicating whether to allow custom operations.
+      When false any unknown operation is an error. When true, custom ops are
+      created for any op that is unknown. The developer will need to provide
+      these to the TensorFlow Lite runtime with a custom resolver. (default
+      False)
+    post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
+      `optimizations` instead. Boolean indicating whether to quantize the
+      weights of the converted float model.  Model size will be reduced and
+      there will be latency improvements (at the cost of accuracy). (default
+      False)
+    dump_graphviz_dir: Full filepath of the folder in which to dump GraphViz
+      .dot files of the graph at various stages of processing. Preferred over
+      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
+      output file. (default None)
+    dump_graphviz_video: Boolean indicating whether to dump the graph after
+      every graph transformation. (default False)
+    conversion_summary_dir: A string indicating the path to the generated
+      conversion logs.
+    target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
+      Set of OpsSet options indicating which converter to use. (default
+      set([OpsSet.TFLITE_BUILTINS]))
+    target_spec: Experimental flag, subject to change. Specification of target
+      device.
+    optimizations: Experimental flag, subject to change. A list of optimizations
+      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
+    representative_dataset: A representative dataset that can be used to
+      generate input and output samples for the model. The converter can use the
+      dataset to evaluate different optimizations.
+    experimental_new_converter: Experimental flag, subject to change. Enables
+      MLIR-based conversion instead of TOCO conversion.
+  """
+
+  def __init__(self, experimental_debug_info_func):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      experimental_debug_info_func: An experimental function to retrieve the
+        graph debug info for a set of nodes from the `graph_def`.
+    """
+    super(TFLiteConverterBaseV1, self).__init__()
+    self.inference_type = constants.FLOAT
+    self.inference_input_type = None
+    self.inference_output_type = None
+    self.output_format = constants.TFLITE
+    self.quantized_input_stats = {}
+    self.default_ranges_stats = None
+    self.drop_control_dependency = True
+    self.reorder_across_fake_quant = False
+    self.change_concat_input_ranges = False
+    self.dump_graphviz_dir = None
+    self.dump_graphviz_video = False
+    self.conversion_summary_dir = None
+    self._debug_info_func = experimental_debug_info_func
+    self._custom_opdefs = None
+
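+  # The two methods below shim the deprecated `post_training_quantize` and
+  # `target_ops` properties onto their replacements by intercepting
+  # attribute access.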
+  def __setattr__(self, name, value):
+    if name == "post_training_quantize":
+      warnings.warn("Property %s is deprecated, "
+                    "please use optimizations=[Optimize.DEFAULT]"
+                    " instead." % name)
+      if value:
+        self.optimizations = [Optimize.DEFAULT]
+      else:
+        self.optimizations = []
+      return
+    if name == "target_ops":
+      warnings.warn("Property %s is deprecated, please use "
+                    "target_spec.supported_ops instead." % name)
+      self.target_spec.supported_ops = value
+      return
+    object.__setattr__(self, name, value)
+
+  def __getattribute__(self, name):
+    if name == "post_training_quantize":
+      warnings.warn("Property %s is deprecated, "
+                    "please use optimizations=[Optimize.DEFAULT]"
+                    " instead." % name)
+      return Optimize.DEFAULT in set(self.optimizations)
+    if name == "target_ops":
+      warnings.warn("Property %s is deprecated, please use "
+                    "target_spec.supported_ops instead." % name)
+      return self.target_spec.supported_ops
+    return object.__getattribute__(self, name)
+
+  def _validate_quantized_input_stats(self, converter_kwargs):
+    """Ensure quantized_input_stats provided if required."""
+
+    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
+
+    requires_quantized_input_stats = (
+        (converter_kwargs["inference_type"] in quantized_types or
+         converter_kwargs["inference_input_type"] in quantized_types) and
+        not converter_kwargs["post_training_quantize"])
+
+    if (requires_quantized_input_stats and
+        not converter_kwargs["quantized_input_stats"]):
+      raise ValueError("std_dev and mean must be defined when inference_type "
+                       "or inference_input_type is QUANTIZED_UINT8 or INT8.")
+
+  def _convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format. Either a TFLite Flatbuffer or a
+      Graphviz graph depending on value in `output_format`.
+
+    Raises:
+      ValueError:
+        Input shape is not specified.
+        None value for dimension in input_tensor.
+    """
+    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
+                                  self.representative_dataset, self._graph_def)
+
+    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
+      # Checks dimensions in input tensor.
+      for tensor in self._input_tensors:
+        shape = tensor.shape
+        if not shape:
+          raise ValueError("Provide an input shape for input array "
+                           "'{0}'.".format(_get_tensor_name(tensor)))
+        # Note that shape_list might be empty for scalar shapes.
+        shape_list = shape.as_list()
+        if None in shape_list[1:]:
+          raise ValueError(
+              "None is only supported in the 1st dimension. Tensor '{0}' has "
+              "invalid shape '{1}'.".format(
+                  _get_tensor_name(tensor), shape_list))
+        elif shape_list and shape_list[0] is None:
+          self._set_batch_size(batch_size=1)
+
+    # Get quantization stats. Ensures there is one stat per name if the stats
+    # are specified.
+    if self.quantized_input_stats:
+      quantized_stats = []
+      invalid_stats = []
+      for name in self.get_input_arrays():
+        if name in self.quantized_input_stats:
+          quantized_stats.append(self.quantized_input_stats[name])
+        else:
+          invalid_stats.append(name)
+
+      if invalid_stats:
+        raise ValueError("Quantization input stats are not available for input "
+                         "tensors '{0}'.".format(",".join(invalid_stats)))
+    else:
+      quantized_stats = None
+
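+    # When post-training optimizations are requested, TOCO must emit a float
+    # model; the requested integer input/output types are realized later
+    # during calibration/quantization.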
+    toco_inference_input_type = self.inference_input_type
+    inference_input_type = self.inference_input_type
+    inference_output_type = self.inference_output_type
+    post_training_optimize = (
+        quant_mode.post_training_int8_no_float() or
+        quant_mode.post_training_int8_allow_float() or
+        quant_mode.post_training_dynamic_range_int8() or
+        quant_mode.post_training_fp16())
+    if post_training_optimize:
+      # Post training optimizations require that TOCO outputs a float model.
+      if self.inference_type != constants.FLOAT:
+        raise ValueError(
+            "`optimizations` require that `inference_type` is set to float.")
+      toco_inference_input_type = constants.FLOAT
+      # Set up default values.
+      if inference_input_type is None:
+        inference_input_type = constants.FLOAT
+      if inference_output_type is None:
+        inference_output_type = constants.FLOAT
+
+    weight_only_quantize = (
+        quant_mode.post_training_dynamic_range_int8() or
+        quant_mode.post_training_fp16())
+    if weight_only_quantize:
+      # Currently, weight only quantization requires float inputs and outputs.
+      if (inference_input_type != constants.FLOAT or
+          inference_output_type != constants.FLOAT):
+        raise ValueError(
+            "Provide an inference_input_type and inference_output_type of type "
+            "tf.float32.")
+
+    if not post_training_optimize and self.inference_output_type is not None:
+      raise ValueError(
+          "inference_output_type is currently not supported if optimizations "
+          "are not enabled.")
+
+    optimized_graph = self._graph_def
+    if not self.saved_model_dir:
+      # Graph optimization is disabled for quantization-aware training. It is
+      # applied when the inference type is not QUANTIZED_UINT8, or when it is
+      # INT8 with post-training or weight-only quantization.
+      if (self.inference_type != constants.QUANTIZED_UINT8 or
+          (self.inference_type == constants.INT8 and
+           (post_training_optimize or weight_only_quantize))):
+        try:
+          # TODO(b/150163103): Merge 'disabling lower using switch merge' calls.
+          # Grappler will also try to lower while loop into switch merge
+          # representation which is undesired for Ophints, so we simply remove
+          # those attributes to prevent Grappler from doing so.
+          graph_def = _convert_to_constants.disable_lower_using_switch_merge(
+              optimized_graph)
+          # Run function inlining optimization to ensure any models generated
+          # through the from_frozen_graph path have been inlined.
+          optimized_graph = _run_graph_optimizations(
+              graph_def,
+              self._input_tensors,
+              self._output_tensors,
+              config=self._grappler_config(["function"]))
+        except Exception:  # pylint: disable=broad-except
+          optimized_graph = self._graph_def
+
+    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
+
+    converter_kwargs = self._get_base_converter_args()
+
+    if quant_mode.post_training_dynamic_range_int8():
+      converter_kwargs.update({
+          "post_training_quantize": True,
+      })
+    elif quant_mode.post_training_fp16():
+      converter_kwargs.update({
+          "post_training_quantize": True,
+          "quantize_to_float16": True,
+      })
+
+    converter_kwargs.update({
+        "inference_type": self.inference_type,
+        "inference_input_type": toco_inference_input_type,
+        "output_format": self.output_format,
+        "quantized_input_stats": quantized_stats,
+        "default_ranges_stats": self.default_ranges_stats,
+        "drop_control_dependency": self.drop_control_dependency,
+        "reorder_across_fake_quant": self.reorder_across_fake_quant,
+        "change_concat_input_ranges": self.change_concat_input_ranges,
+        "dump_graphviz_dir": self.dump_graphviz_dir,
+        "dump_graphviz_video": self.dump_graphviz_video,
+        "conversion_summary_dir": self.conversion_summary_dir,
+        "custom_opdefs": self._custom_opdefs,
+    })
+
+    if not self.experimental_new_converter:
+      logging.warning(
+          "Please consider switching to use new converter by setting "
+          "experimental_new_converter to true. "
+          "Old converter (TOCO) is deprecated and flow will be switched on "
+          "by default to use new converter soon.")
+    else:
+      logging.info("Using experimental converter: If you encountered a problem "
+                   "please file a bug. You can opt-out "
+                   "by setting experimental_new_converter=False")
+
+    self._validate_quantized_input_stats(converter_kwargs)
+
+    # Converts model.
+    if self._has_valid_tensors():
+      result = _toco_convert_impl(
+          input_data=optimized_graph,
+          input_tensors=self._input_tensors,
+          output_tensors=self._output_tensors,
+          **converter_kwargs)
+    else:
+      result = _toco_convert_graph_def(
+          input_data=optimized_graph,
+          input_arrays_with_shape=self._input_arrays_with_shape,
+          output_arrays=self._output_arrays,
+          **converter_kwargs)
+
+    if quant_mode.post_training_int8_no_float():
+      result = self._calibrate_quantize_model(result, inference_input_type,
+                                              inference_output_type, False)
+    elif quant_mode.post_training_int8_allow_float():
+      result = self._calibrate_quantize_model(result, inference_input_type,
+                                              inference_output_type, True)
+
+    if self._experimental_sparsify_model:
+      result = _mlir_sparsify(result)
+
+    return result
+
+  def get_input_arrays(self):
+    """Returns a list of the names of the input tensors.
+
+    Returns:
+      List of strings.
+    """
+    if self._has_valid_tensors():
+      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
+    else:
+      return [name for name, _ in self._input_arrays_with_shape]
+
+  def _has_valid_tensors(self):
+    """Checks if the input and output tensors have been initialized.
+
+    Returns:
+      Bool.
+    """
+    return self._input_tensors and self._output_tensors
+
+  def _set_batch_size(self, batch_size):
+    """Sets the first dimension of the input tensor to `batch_size`.
+
+    Args:
+      batch_size: Batch size for the model. Replaces the first dimension of an
+        input size array if undefined. (default 1)
+
+    Raises:
+      ValueError: input_tensor is not defined.
+    """
+    if not self._has_valid_tensors():
+      raise ValueError("The batch size cannot be set for this model. Please "
+                       "use input_shapes parameter.")
+
+    for tensor in self._input_tensors:
+      shape = tensor.shape.as_list()
+      if shape[0] is None:
+        shape[0] = batch_size
+        tensor.set_shape(shape)
+
+  def _is_unknown_shapes_allowed(self):
+    # Ophint-converted nodes need the shapes to be known.
+    if _is_ophint_converted(self._graph_def):
+      return False
+
+    if not super(TFLiteConverterBaseV1, self)._is_unknown_shapes_allowed():
+      return False
+
+    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
+    # the MLIR converter.
+    if self.conversion_summary_dir:
+      logging.warning(
+          "`conversion_summary_dir` does not work with unknown shapes. "
+          "Graphs with unknown shapes might be different than when this flag "
+          "is disabled.")
+      return False
+    return True
+
+
+class TFLiteSavedModelConverter(TFLiteConverterBaseV1):
+  """Converts the given SavedModel into TensorFlow Lite model.
+
+  Attributes:
+    saved_model_dir: Directory of the SavedModel.
+  """
+
+  def __init__(self,
+               saved_model_dir,
+               saved_model_tags,
+               saved_model_exported_names,
+               experimental_debug_info_func=None):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      saved_model_dir: Directory of the SavedModel.
+      saved_model_tags: Set of tags identifying the MetaGraphDef within the
+        SavedModel to analyze. All tags in the tag set must be present. (default
+        set(SERVING)).
+      saved_model_exported_names: Names to be exported when the SavedModel
+        import path is enabled. (default: export all)
+      experimental_debug_info_func: An experimental function to retrieve the
+        graph debug info for a set of nodes from the `graph_def`.
+
+    Raises:
+      ValueError: Invalid arguments.
+    """
+    super(TFLiteSavedModelConverter,
+          self).__init__(experimental_debug_info_func)
+    self.saved_model_dir = saved_model_dir
+    self._saved_model_tags = saved_model_tags
+    self._saved_model_exported_names = saved_model_exported_names
+
+    if len(self._saved_model_exported_names) != 1:
+      raise ValueError("Only a single signature key is supported.")
+
+    signature_key = self._saved_model_exported_names[0]
+
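+    # _freeze_saved_model returns (graph_def, input_tensors, output_tensors,
+    # graph) for the frozen SavedModel.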
+    result = _freeze_saved_model(self.saved_model_dir, None, None, None,
+                                 self._saved_model_tags, signature_key)
+    self._graph_def = result[0]
+    self._input_tensors = result[1]
+    self._output_tensors = result[2]
+    self._parse_saved_model_args()
+
+  def convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format. Either a TFLite Flatbuffer or a
+      Graphviz graph depending on value in `output_format`.
+
+    Raises:
+      ValueError:
+        Input shape is not specified.
+        None value for dimension in input_tensor.
+    """
+    return self._convert()
+
+
+class TFLiteFrozenGraphConverter(TFLiteConverterBaseV1):
+  """Converts the given frozen graph def into TensorFlow Lite model."""
+
+  def __init__(self,
+               graph_def,
+               input_tensors,
+               output_tensors,
+               input_arrays_with_shape=None,
+               output_arrays=None,
+               experimental_debug_info_func=None):
+    """Constructor for TFLiteConverter.
+
+    Args:
+      graph_def: Frozen TensorFlow GraphDef.
+      input_tensors: List of input tensors. Type and shape are computed using
+        `foo.shape` and `foo.dtype`.
+      output_tensors: List of output tensors (only their names are used).
+      input_arrays_with_shape: List of tuples of strings representing input
+        tensor names and lists of integers representing input shapes
+        (e.g., [("foo", [1, 16, 16, 3])]). Use only when the graph cannot be
+        loaded into TensorFlow and when `input_tensors` and `output_tensors`
+        are None. (default None)
+      output_arrays: List of output tensors to freeze graph with. Use only when
+        graph cannot be loaded into TensorFlow and when `input_tensors` and
+        `output_tensors` are None. (default None)
+      experimental_debug_info_func: An experimental function to retrieve the
+        graph debug info for a set of nodes from the `graph_def`.
+
+    Raises:
+      ValueError: Invalid arguments.
+    """
+    super(TFLiteFrozenGraphConverter,
+          self).__init__(experimental_debug_info_func)
+    self._graph_def = graph_def
+    self._input_tensors = input_tensors
+    self._output_tensors = output_tensors
+
+    # These attributes are used when the model cannot be loaded into
+    # TensorFlow.
+    if not self._has_valid_tensors():
+      if not input_arrays_with_shape or not output_arrays:
+        raise ValueError(
+            "If input_tensors and output_tensors are None, both "
+            "input_arrays_with_shape and output_arrays must be defined.")
+      self._input_arrays_with_shape = input_arrays_with_shape
+      self._output_arrays = output_arrays
+
+  def convert(self):
+    """Converts a TensorFlow GraphDef based on instance variables.
+
+    Returns:
+      The converted data in serialized format. Either a TFLite Flatbuffer or a
+      Graphviz graph depending on value in `output_format`.
+
+    Raises:
+      ValueError:
+        Input shape is not specified.
+        None value for dimension in input_tensor.
+    """
+    return self._convert()
+
+
 @_tf_export(v1=["lite.TFLiteConverter"])
-class TFLiteConverter(TFLiteConverterBase):
+class TFLiteConverter(TFLiteFrozenGraphConverter):
   """Convert a TensorFlow model into `output_format`.
 
   This is used to convert from a TensorFlow GraphDef, SavedModel or tf.keras
@@ -831,15 +1426,14 @@ class TFLiteConverter(TFLiteConverterBase):
     ```
   """
 
+  # pylint: disable=useless-super-delegation
   def __init__(self,
                graph_def,
                input_tensors,
                output_tensors,
                input_arrays_with_shape=None,
                output_arrays=None,
-               experimental_debug_info_func=None,
-               saved_model_dir=None,
-               saved_model_tags=None):
+               experimental_debug_info_func=None):
     """Constructor for TFLiteConverter.
 
     Args:
@@ -857,47 +1451,14 @@ class TFLiteConverter(TFLiteConverterBase):
         `output_tensors` are None. (default None)
       experimental_debug_info_func: An experimental function to retrieve the
         graph debug info for a set of nodes from the `graph_def`.
-      saved_model_dir: Directory of the SavedModel. This argument can be null
-        when it creates via the from_keras_model and from_concrete_function
-        methods.
-      saved_model_tags: Set of tags identifying the MetaGraphDef within the
-        SavedModel to analyze. All tags in the tag set must be present. (default
-        set(SERVING)).  This argument will be available when the saved model dir
-        argument is set.
 
     Raises:
       ValueError: Invalid arguments.
     """
-    super(TFLiteConverter, self).__init__()
-    self._graph_def = graph_def
-    self._input_tensors = input_tensors
-    self._output_tensors = output_tensors
-    self.inference_type = constants.FLOAT
-    self.inference_input_type = None
-    self.inference_output_type = None
-    self.output_format = constants.TFLITE
-    self.quantized_input_stats = {}
-    self.default_ranges_stats = None
-    self.drop_control_dependency = True
-    self.reorder_across_fake_quant = False
-    self.change_concat_input_ranges = False
-    self._post_training_quantize = False
-    self.dump_graphviz_dir = None
-    self.dump_graphviz_video = False
-    self.conversion_summary_dir = None
-    self._debug_info_func = experimental_debug_info_func
-    self._custom_opdefs = None
-    self._saved_model_dir = saved_model_dir
-    self._saved_model_tags = saved_model_tags
-
-    # Attributes are used by models that cannot be loaded into TensorFlow.
-    if not self._has_valid_tensors():
-      if not input_arrays_with_shape or not output_arrays:
-        raise ValueError(
-            "If input_tensors and output_tensors are None, both "
-            "input_arrays_with_shape and output_arrays must be defined.")
-      self._input_arrays_with_shape = input_arrays_with_shape
-      self._output_arrays = output_arrays
+    super(TFLiteConverter,
+          self).__init__(graph_def, input_tensors, output_tensors,
+                         input_arrays_with_shape, output_arrays,
+                         experimental_debug_info_func)
 
   @classmethod
   def from_session(cls, sess, input_tensors, output_tensors):
@@ -1045,15 +1606,19 @@ class TFLiteConverter(TFLiteConverterBase):
     if signature_key is None:
       signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
+    saved_model_converter = TFLiteSavedModelConverter(saved_model_dir, tag_set,
+                                                      [signature_key])
+    if saved_model_converter.saved_model_dir:
+      return saved_model_converter
+
     result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                  output_arrays, tag_set, signature_key)
+
     return cls(
         graph_def=result[0],
         input_tensors=result[1],
         output_tensors=result[2],
-        experimental_debug_info_func=_build_debug_info_func(result[3]),
-        saved_model_dir=saved_model_dir,
-        saved_model_tags=tag_set)
+        experimental_debug_info_func=_build_debug_info_func(result[3]))
 
   @classmethod
   def from_keras_model_file(cls,
@@ -1128,50 +1693,7 @@ class TFLiteConverter(TFLiteConverterBase):
         output_tensors,
         experimental_debug_info_func=_build_debug_info_func(sess.graph))
 
-  def __setattr__(self, name, value):
-    if name == "post_training_quantize":
-      warnings.warn("Property %s is deprecated, "
-                    "please use optimizations=[Optimize.DEFAULT]"
-                    " instead." % name)
-      if value:
-        self.optimizations = [Optimize.DEFAULT]
-      else:
-        self.optimizations = []
-      return
-    if name == "target_ops":
-      warnings.warn("Property %s is deprecated, please use "
-                    "target_spec.supported_ops instead." % name)
-      self.target_spec.supported_ops = value
-      return
-    object.__setattr__(self, name, value)
-
-  def __getattribute__(self, name):
-    if name == "post_training_quantize":
-      warnings.warn("Property %s is deprecated, "
-                    "please use optimizations=[Optimize.DEFAULT]"
-                    " instead." % name)
-      return Optimize.DEFAULT in set(self.optimizations)
-    if name == "target_ops":
-      warnings.warn("Property %s is deprecated, please use "
-                    "target_spec.supported_ops instead." % name)
-      return self.target_spec.supported_ops
-    return object.__getattribute__(self, name)
-
-  def _validate_quantized_input_stats(self, converter_kwargs):
-    """Ensure quantized_input_stats provided if required."""
-
-    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
-
-    requires_quantized_input_stats = (
-        (converter_kwargs["inference_type"] in quantized_types or
-         converter_kwargs["inference_input_type"] in quantized_types) and
-        not converter_kwargs["post_training_quantize"])
-
-    if (requires_quantized_input_stats and
-        not converter_kwargs["quantized_input_stats"]):
-      raise ValueError("std_dev and mean must be defined when inference_type "
-                       "or inference_input_type is QUANTIZED_UINT8 or INT8.")
-
+  # pylint: disable=useless-super-delegation
   def convert(self):
     """Converts a TensorFlow GraphDef based on instance variables.
 
@@ -1184,231 +1706,7 @@ class TFLiteConverter(TFLiteConverterBase):
         Input shape is not specified.
         None value for dimension in input_tensor.
     """
-    # Parses SavedModel argument.
-    self._parse_saved_model_args()
-
-    quant_mode = QuantizationMode(self.optimizations, self.target_spec,
-                                  self.representative_dataset, self._graph_def)
-
-    # Checks dimensions in input tensor.
-    if (not self._is_unknown_shapes_allowed() and self._has_valid_tensors()):
-      for tensor in self._input_tensors:
-        shape = tensor.shape
-        if not shape:
-          raise ValueError("Provide an input shape for input array "
-                           "'{0}'.".format(_get_tensor_name(tensor)))
-        # Note that shape_list might be empty for scalar shapes.
-        shape_list = shape.as_list()
-        if None in shape_list[1:]:
-          raise ValueError(
-              "None is only supported in the 1st dimension. Tensor '{0}' has "
-              "invalid shape '{1}'.".format(
-                  _get_tensor_name(tensor), shape_list))
-        elif shape_list and shape_list[0] is None:
-          self._set_batch_size(batch_size=1)
-
-    # Get quantization stats. Ensures there is one stat per name if the stats
-    # are specified.
-    if self.quantized_input_stats:
-      quantized_stats = []
-      invalid_stats = []
-      for name in self.get_input_arrays():
-        if name in self.quantized_input_stats:
-          quantized_stats.append(self.quantized_input_stats[name])
-        else:
-          invalid_stats.append(name)
-
-      if invalid_stats:
-        raise ValueError("Quantization input stats are not available for input "
-                         "tensors '{0}'.".format(",".join(invalid_stats)))
-    else:
-      quantized_stats = None
-
-    toco_inference_input_type = self.inference_input_type
-    inference_input_type = self.inference_input_type
-    inference_output_type = self.inference_output_type
-    post_training_optimize = (
-        quant_mode.post_training_int8_no_float() or
-        quant_mode.post_training_int8_allow_float() or
-        quant_mode.post_training_dynamic_range_int8() or
-        quant_mode.post_training_fp16())
-    if post_training_optimize:
-      # Post training optimizations require that TOCO outputs a float model.
-      if self.inference_type != constants.FLOAT:
-        raise ValueError(
-            "`optimizations` require that `inference_type` is set to float.")
-      toco_inference_input_type = constants.FLOAT
-      # Set up default values.
-      if inference_input_type is None:
-        inference_input_type = constants.FLOAT
-      if inference_output_type is None:
-        inference_output_type = constants.FLOAT
-
-    weight_only_quantize = (
-        quant_mode.post_training_dynamic_range_int8() or
-        quant_mode.post_training_fp16())
-    if weight_only_quantize:
-      # Currently, weight-only quantization requires float inputs and outputs.
-      if (inference_input_type != constants.FLOAT or
-          inference_output_type != constants.FLOAT):
-        raise ValueError(
-            "Provide an inference_input_type and inference_output_type of type "
-            "tf.float32.")
-
-    if not post_training_optimize and self.inference_output_type is not None:
-      raise ValueError(
-          "inference_output_type is currently not supported if optimizations "
-          "are not enabled.")
-
-    optimized_graph = self._graph_def
-    if not self._saved_model_dir:
-      # If the inference type is not QUANTIZED_UINT8, or it is INT8 with
-      # post-training quantization, this is not quantization-aware training,
-      # so graph optimization is applied; graph optimization is disabled for
-      # quantization-aware training.
-      if (self.inference_type != constants.QUANTIZED_UINT8 or
-          (self.inference_type == constants.INT8 and
-           (post_training_optimize or weight_only_quantize))):
-        try:
-          # TODO(b/150163103): Merge 'disable lower using switch merge' calls.
-          # Grappler will also try to lower while loop into switch merge
-          # representation which is undesired for Ophints, so we simply remove
-          # those attributes to prevent Grappler from doing so.
-          graph_def = _convert_to_constants.disable_lower_using_switch_merge(
-              optimized_graph)
-          # Run function inlining optimization to ensure any models generated
-          # through the from_frozen_graph path have been inlined.
-          optimized_graph = _run_graph_optimizations(
-              graph_def,
-              self._input_tensors,
-              self._output_tensors,
-              config=self._grappler_config(["function"]))
-        except Exception:
-          optimized_graph = self._graph_def
-
-    self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
-
-    converter_kwargs = self._get_base_converter_args()
-
-    if quant_mode.post_training_dynamic_range_int8():
-      converter_kwargs.update({
-          "post_training_quantize": True,
-      })
-    elif quant_mode.post_training_fp16():
-      converter_kwargs.update({
-          "post_training_quantize": True,
-          "quantize_to_float16": True,
-      })
-
-    converter_kwargs.update({
-        "inference_type": self.inference_type,
-        "inference_input_type": toco_inference_input_type,
-        "output_format": self.output_format,
-        "quantized_input_stats": quantized_stats,
-        "default_ranges_stats": self.default_ranges_stats,
-        "drop_control_dependency": self.drop_control_dependency,
-        "reorder_across_fake_quant": self.reorder_across_fake_quant,
-        "change_concat_input_ranges": self.change_concat_input_ranges,
-        "dump_graphviz_dir": self.dump_graphviz_dir,
-        "dump_graphviz_video": self.dump_graphviz_video,
-        "conversion_summary_dir": self.conversion_summary_dir,
-        "custom_opdefs": self._custom_opdefs,
-    })
-
-    if not self.experimental_new_converter:
-      logging.warning(
-          "Please consider switching to the new converter by setting "
-          "experimental_new_converter=True. The old converter (TOCO) is "
-          "deprecated, and conversion will soon switch to the new converter "
-          "by default.")
-    else:
-      logging.info("Using the experimental new converter: if you encounter "
-                   "a problem, please file a bug. You can opt out by "
-                   "setting experimental_new_converter=False.")
-
-    self._validate_quantized_input_stats(converter_kwargs)
-
-    # Converts model.
-    if self._has_valid_tensors():
-      result = _toco_convert_impl(
-          input_data=optimized_graph,
-          input_tensors=self._input_tensors,
-          output_tensors=self._output_tensors,
-          **converter_kwargs)
-    else:
-      result = _toco_convert_graph_def(
-          input_data=optimized_graph,
-          input_arrays_with_shape=self._input_arrays_with_shape,
-          output_arrays=self._output_arrays,
-          **converter_kwargs)
-
-    if quant_mode.post_training_int8_no_float():
-      result = self._calibrate_quantize_model(result, inference_input_type,
-                                              inference_output_type, False)
-    elif quant_mode.post_training_int8_allow_float():
-      result = self._calibrate_quantize_model(result, inference_input_type,
-                                              inference_output_type, True)
-
-    if self._experimental_sparsify_model:
-      result = _mlir_sparsify(result)
-
-    return result
-
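
From the public API, the branches above are selected through `optimizations`
and `target_spec` rather than set directly; a minimal sketch of the fp16
path, reusing a `converter` constructed as elsewhere in this file:

    import tensorflow as tf

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    # quant_mode.post_training_fp16() now holds, so the kwargs gain
    # post_training_quantize=True and quantize_to_float16=True.
    tflite_model = converter.convert()
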
-  def get_input_arrays(self):
-    """Returns a list of the names of the input tensors.
-
-    Returns:
-      List of strings.
-    """
-    if self._has_valid_tensors():
-      return [_get_tensor_name(tensor) for tensor in self._input_tensors]
-    else:
-      return [name for name, _ in self._input_arrays_with_shape]
-
-  def _has_valid_tensors(self):
-    """Checks if the input and output tensors have been initialized.
-
-    Returns:
-      Bool.
-    """
-    return self._input_tensors and self._output_tensors
-
-  def _set_batch_size(self, batch_size):
-    """Sets the first dimension of the input tensor to `batch_size`.
-
-    Args:
-      batch_size: Batch size for the model. Replaces the first dimension of
-        an input shape if it is undefined. (default 1)
-
-    Raises:
-      ValueError: input_tensor is not defined.
-    """
-    if not self._has_valid_tensors():
-      raise ValueError("The batch size cannot be set for this model. Please "
-                       "use input_shapes parameter.")
-
-    for tensor in self._input_tensors:
-      shape = tensor.shape.as_list()
-      if shape[0] is None:
-        shape[0] = batch_size
-        tensor.set_shape(shape)
-
-  def _is_unknown_shapes_allowed(self):
-    # Ophint-converted nodes need their shapes to be known.
-    if _is_ophint_converted(self._graph_def):
-      return False
-
-    if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
-      return False
-
-    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
-    # the MLIR converter.
-    if self.conversion_summary_dir:
-      logging.warning(
-          "`conversion_summary_dir` does not work with unknown shapes. "
-          "Graphs with unknown shapes might be different than when this flag "
-          "is disabled.")
-      return False
-    return True
+    return super(TFLiteConverter, self).convert()
 
 
 @_tf_export(v1=["lite.TocoConverter"])
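
Note that the removed `_set_batch_size` path means the V1 flow pins an
undefined leading dimension to batch size 1 instead of rejecting it; a
minimal sketch (the placeholder model is hypothetical):

    import tensorflow as tf

    with tf.Graph().as_default():
      in_tensor = tf.compat.v1.placeholder(
          shape=[None, 16, 16, 3], dtype=tf.float32)  # unknown batch dim
      out_tensor = in_tensor + in_tensor
      sess = tf.compat.v1.Session()

    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [in_tensor], [out_tensor])
    tflite_model = converter.convert()  # the leading None becomes 1
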
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 3ff9fbc3710..530c514eb96 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -2318,5 +2318,42 @@ class ImportOpsUtilTest(LiteTest):
     self.assertIsNotNone(lite.get_potentially_supported_ops())
 
 
+class DefaultConverterAttrsTest(LiteTest):
+
+  def testAttrs(self):
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
+
+    # Convert model.
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+
+    # Assert output format.
+    self.assertEqual(converter.output_format, lite_constants.TFLITE)
+
+    # Assert the default inference type is float.
+    self.assertEqual(converter.inference_type, lite_constants.FLOAT)
+
+    # Assert the default inference type overrides are None.
+    self.assertIsNone(converter.inference_input_type)
+    self.assertIsNone(converter.inference_output_type)
+
+    # Assert the default quantization options are not set.
+    self.assertEqual(converter.quantized_input_stats, {})
+    self.assertIsNone(converter.default_ranges_stats)
+    self.assertFalse(converter.reorder_across_fake_quant)
+    self.assertFalse(converter.change_concat_input_ranges)
+
+    # Assert dropping control dependency is enabled by default.
+    self.assertTrue(converter.drop_control_dependency)
+
+    # Assert dumping extra information is disabled by default.
+    self.assertIsNone(converter.dump_graphviz_dir)
+    self.assertFalse(converter.dump_graphviz_video)
+    self.assertIsNone(converter.conversion_summary_dir)
+
+
 if __name__ == '__main__':
   test.main()
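
The defaults asserted above are plain attributes, so callers override them by
assignment; a short sketch, reusing the session and tensors from the test:

    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.drop_control_dependency = False      # default: True
    converter.dump_graphviz_dir = "/tmp/graphviz"  # hypothetical path
    tflite_model = converter.convert()
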
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index d351fd492f6..59f326d4b9f 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -981,6 +981,27 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     np.testing.assert_almost_equal(
         expected_value.numpy(), actual_value[0], decimal=4)
 
+  def testSizeInvalid(self):
+
+    @tf.function(input_signature=[
+        tf.TensorSpec(shape=[1, None, 16, 3], dtype=tf.float32)
+    ])
+    def model(in_tensor):
+      return in_tensor + in_tensor
+
+    concrete_func = model.get_concrete_function()
+
+    # Test an invalid shape: `None` after the 1st dimension. Run with TOCO
+    # in order to invoke the shape checking code.
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_new_converter = False
+    with self.assertRaises(ValueError) as error:
+      converter.convert()
+    self.assertEqual(
+        'None is only supported in the 1st dimension. Tensor '
+        '\'in_tensor\' has invalid shape \'[1, None, 16, 3]\'.',
+        str(error.exception))
+
 
 if __name__ == '__main__':
   test.main()
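
For contrast with `testSizeInvalid`, `None` in the 1st dimension is accepted,
and the MLIR converter also handles unknown shapes that TOCO rejects; a
minimal sketch of the supported case:

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, 16, 3], dtype=tf.float32)
    ])
    def model(in_tensor):
      return in_tensor + in_tensor

    converter = lite.TFLiteConverterV2.from_concrete_functions(
        [model.get_concrete_function()])
    converter.experimental_new_converter = True  # MLIR handles unknown shapes
    tflite_model = converter.convert()
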
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt
index 0c43fc556aa..e7689b4320f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-t-f-lite-converter.pbtxt
@@ -1,11 +1,13 @@
 path: "tensorflow.lite.TFLiteConverter"
 tf_class {
   is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverter\'>"
+  is_instance: "<class \'tensorflow.lite.python.lite.TFLiteFrozenGraphConverter\'>"
+  is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverterBaseV1\'>"
   is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverterBase\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "convert"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt
index c575283b74d..c8c163d2f2a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-t-f-lite-converter.pbtxt
@@ -1,11 +1,13 @@
 path: "tensorflow.lite.TFLiteConverter"
 tf_class {
   is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverterV2\'>"
+  is_instance: "<class \'tensorflow.lite.python.lite.TFLiteFrozenGraphConverterV2\'>"
+  is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverterBaseV2\'>"
   is_instance: "<class \'tensorflow.lite.python.lite.TFLiteConverterBase\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'funcs\', \'trackable_obj\', \'saved_model_dir\', \'saved_model_tags\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'funcs\', \'trackable_obj\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "convert"