diff --git a/RELEASE.md b/RELEASE.md
index 12b5168954b..7895a0ba113 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -56,6 +56,8 @@
 *   `tf.lite`:
     * Better support for ops with high-dimensional broadcasting inputs by adding
   `BroadcastTo` ops when necessary.
+    * `TFLiteConverter`:
+      * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (tf.int8, tf.uint8) instead of defaulting to float type (tf.float32).
 *   `tf.random`:
     * <ADD RELEASE NOTES HERE>
 *   Math and Linear Algebra:
@@ -68,7 +70,7 @@
     * <ADD RELEASE NOTES HERE>
 *   Other:
     * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
-  and "denylist" where possible. Please see 
+  and "denylist" where possible. Please see
   https://developers.google.com/style/word-list#blacklist for more context.
     * <ADD RELEASE NOTES HERE>
 
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index e26000c810a..55a2a69675d 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -212,8 +212,11 @@ py_library(
     deps = [
         ":lite_constants",
         ":op_hint",
+        ":schema_py",
         "//tensorflow/python:tf_optimizer",
         "//tensorflow/python/eager:wrap_function",
+        "@absl_py//absl/logging",
+        "@flatbuffers//:runtime_py",
         "@six_archive//:six",
     ],
 )
@@ -224,12 +227,24 @@ py_test(
     python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [
+        "no_mac",
         "no_windows",
     ],
     deps = [
+        ":lite_constants",
         ":util",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:convert_to_constants",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:session",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index e919aa4b00f..a08b40bbed6 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
 from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
 from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
 from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
+from tensorflow.lite.python.util import modify_integer_quantized_model_io_type as _modify_integer_quantized_model_io_type
 from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
 from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
 from tensorflow.python import keras as _keras
@@ -314,6 +315,23 @@ class QuantizationMode(object):
     else:
       return False, None
 
+  def flags_modify_model_io_type(
+      self, input_type=constants.FLOAT, output_type=constants.FLOAT):
+    """Flags for modifying the input and output type of a tflite model."""
+    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
+    is_training_time_only_quantize = self.training_time_int8_allow_float() and \
+        not is_post_training_quantize
+
+    # TODO(b/153576658): Consolidate post/during training quantization workflows
+    # to modify model input/output type after MLIR conversion.
+    if is_training_time_only_quantize:
+      return {
+          "inference_input_type": input_type,
+          "inference_output_type": output_type,
+      }
+    else:
+      return None
+
   # Below are helpers for the above functions.
 
   def _validate_int8_required(self):
@@ -557,9 +575,8 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
   def _validate_inference_input_output_types(self, quant_mode):
     """Validate inference_input_type and inference_output_type flags."""
     default_types = [constants.FLOAT, None]
-    # We only support integer types for post training integer quantization
-    # as we have statistical information to quantize the input and output.
-    if quant_mode.is_post_training_integer_quantize():
+    # We support integer input/output for integer quantized models only.
+    if quant_mode.training_time_int8_allow_float():
       all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
       if self.inference_input_type not in all_types or \
           self.inference_output_type not in all_types:
@@ -643,6 +660,12 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
     if calibrate_and_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
+    flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
+        self.inference_input_type, self.inference_output_type)
+    if flags_modify_model_io_type:
+      result = _modify_integer_quantized_model_io_type(
+          result, **flags_modify_model_io_type)
+
     if self._experimental_sparsify_model:
       result = _mlir_sparsify(result)
 
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 6fab4fd6086..4093a9d5bb4 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -374,8 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 
     return tf.keras.Sequential(QLinear(3, input_shape=(2,)))
 
+  @parameterized.named_parameters(
+      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
+      ('_INT8InputOutput', lite.constants.INT8),
+      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
   @test_util.run_v2_only
-  def testTrainingTimeQuantization(self):
+  def testTrainingTimeQuantization(self, inference_input_output_type):
     model = self._getTrainingTimeQuantizedModel()
 
     float_converter = lite.TFLiteConverterV2.from_keras_model(model)
@@ -384,37 +388,24 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
 
     quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.inference_input_type = inference_input_output_type
+    quantized_converter.inference_output_type = inference_input_output_type
     quantized_tflite = quantized_converter.convert()
     self.assertTrue(quantized_tflite)
 
-    # Ensure that the quantized weights tflite model is smaller.
-    self.assertLess(len(quantized_tflite), len(float_tflite))
-
     interpreter = Interpreter(model_content=quantized_tflite)
-    self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     output_details[0]['dtype'])
 
-  @parameterized.named_parameters(
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
-  def testInvalidTrainingTimeQuantization(self, inference_input_output_type):
-    # We currently don't support integer inference_input_type and
-    # inference_output_type flags for training time quantization.
-
-    model = self._getTrainingTimeQuantizedModel()
-
-    converter = lite.TFLiteConverterV2.from_keras_model(model)
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
-
-    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
-    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    with self.assertRaises(ValueError) as error:
-      quantized_converter.inference_input_type = inference_input_output_type
-      quantized_converter.inference_output_type = inference_input_output_type
-      quantized_converter.convert()
-    self.assertEqual(
-        'The inference_input_type and inference_output_type '
-        'must be tf.float32.', str(error.exception))
+    # Ensure that the quantized tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
 
   @test_util.run_v2_only
   def testNewQuantizer(self):
diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py
index ff7caad0f88..9f84681c12b 100644
--- a/tensorflow/lite/python/util.py
+++ b/tensorflow/lite/python/util.py
@@ -19,15 +19,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import datetime
 import sys
 
+from absl import logging
+
 import six
 from six.moves import range
 
+from flatbuffers.python import flatbuffers
 from tensorflow.core.protobuf import config_pb2 as _config_pb2
 from tensorflow.core.protobuf import graph_debug_info_pb2
 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
+from tensorflow.lite.python import lite_constants as _lite_constants
+from tensorflow.lite.python import schema_py_generated as _schema_fb
 from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
 from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
 from tensorflow.lite.toco import types_pb2 as _types_pb2
@@ -55,6 +61,25 @@ _MAP_TF_TO_TFLITE_TYPES = {
     dtypes.bool: _types_pb2.BOOL,
 }
 
+_MAP_TFLITE_ENUM_TO_TF_TYPES = {
+    0: dtypes.float32,
+    1: dtypes.float16,
+    2: dtypes.int32,
+    3: dtypes.uint8,
+    4: dtypes.int64,
+    5: dtypes.string,
+    6: dtypes.bool,
+    7: dtypes.int16,
+    8: dtypes.complex64,
+    9: dtypes.int8,
+    10: dtypes.float64,
+}
+
+_TFLITE_FILE_IDENTIFIER = b"TFL3"
+
+_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8,
+                                    _lite_constants.QUANTIZED_UINT8)
+
 
 def convert_dtype_to_tflite_type(tf_dtype):
   """Converts tf.dtype to TFLite proto type.
@@ -74,6 +99,31 @@ def convert_dtype_to_tflite_type(tf_dtype):
   return result
 
 
+def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
+  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).
+
+  Args:
+    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)
+
+  Raises:
+    ValueError: If an invalid tflite enum type is provided.
+
+  Returns:
+    tf type (eg: tf.float32)
+  """
+  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
+  if tf_type is None:
+    raise ValueError(
+        "Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
+        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
+  return tf_type
+
+
+def _get_dtype_name(tf_type):
+  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
+  return "tf." + tf_type.name
+
+
 def get_tensor_name(tensor):
   """Returns name of the input tensor.
 
@@ -514,3 +564,218 @@ extern const int {array_name}_len;
       license_text=license_text)
 
   return source_text, header_text
+
+
+def _convert_model_from_bytearray_to_object(model_bytearray):
+  """Converts a tflite model from a bytearray into a parsable object."""
+  model_object = _schema_fb.Model.GetRootAsModel(model_bytearray, 0)
+  model_object = _schema_fb.ModelT.InitFromObj(model_object)
+  model_object = copy.deepcopy(model_object)
+  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
+  return model_object
+
+
+def _convert_model_from_object_to_bytearray(model_object):
+  """Converts a tflite model from a parsable object into a bytearray."""
+  # Initial size of the buffer, which will grow automatically if needed
+  builder = flatbuffers.Builder(1024)
+  model_offset = model_object.Pack(builder)
+  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
+  return bytes(builder.Output())
+
+
+def _remove_tensors_from_model(model, remove_tensors_idxs):
+  """Remove tensors from model."""
+  if not remove_tensors_idxs:
+    return
+  if len(model.subgraphs) > 1:
+    raise ValueError("Model must only have one subgraph. Instead, it has "
+                     "{} subgraphs.".format(len(model.subgraphs)))
+  subgraph = model.subgraphs[0]
+  tensors = subgraph.tensors
+  operators = subgraph.operators
+
+  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
+  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
+  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
+  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
+    logging.debug("Removing tensors only at the end of the tensor list")
+    del tensors[min(remove_tensors_idxs):]
+  else:
+    logging.debug("Removing tensors requires updating the model")
+    # Map the old tensor indices to new tensor indices
+    d_old_to_new_tensors = {}
+    left_shift_by = 0
+    for idx in range(len(tensors)):
+      if idx in remove_tensors_idxs:
+        left_shift_by += 1
+      else:
+        d_old_to_new_tensors[idx] = idx - left_shift_by
+    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
+    # Update tensor indices referenced throughout the model
+    def update_tensors(tensor_idxs):
+      for i, ti in enumerate(tensor_idxs):
+        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
+    update_tensors(subgraph.inputs)
+    update_tensors(subgraph.outputs)
+    for op in operators:
+      update_tensors(op.inputs)
+      update_tensors(op.outputs)
+    # Delete the tensors
+    for idx in sorted(remove_tensors_idxs, reverse=True):
+      tensors.pop(idx)
+    logging.debug("Removed tensors marked for deletion")
+
+
+def _validate_and_find_int8_quantized_inputs_outputs(model):
+  """Validate that model input is quantized and output is dequantized."""
+  if len(model.subgraphs) > 1:
+    raise ValueError("Model must only have one subgraph. Instead, it has "
+                     "{} subgraphs.".format(len(model.subgraphs)))
+  subgraph = model.subgraphs[0]
+  tensors = subgraph.tensors
+  operators = subgraph.operators
+
+  # Ensure model has atleast one quantize and dequantize operator
+  quant_opcode_idx, dequant_opcode_idx = None, None
+  for idx, opcode in enumerate(model.operatorCodes):
+    if opcode.builtinCode == _schema_fb.BuiltinOperator.QUANTIZE:
+      quant_opcode_idx = idx
+    elif opcode.builtinCode == _schema_fb.BuiltinOperator.DEQUANTIZE:
+      dequant_opcode_idx = idx
+    if quant_opcode_idx is not None and dequant_opcode_idx is not None:
+      break
+  if quant_opcode_idx is None and dequant_opcode_idx is None:
+    raise ValueError("Model is not integer quantized as it does not "
+                     "contain quantize/dequantize operators.")
+
+  # Ensure model inputs and outputs are integer quantized
+  input_quant_ops, output_dequant_ops = [], []
+  for op in operators:
+    # Find input quantize operator
+    if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
+      pos, float_tensor, int_tensor = \
+          "input", tensors[op.inputs[0]], tensors[op.outputs[0]]
+      input_quant_ops.append(op)
+    # Find output dequantize operator
+    elif op.opcodeIndex == dequant_opcode_idx and \
+        op.outputs[0] in subgraph.outputs:
+      pos, float_tensor, int_tensor = \
+          "output", tensors[op.outputs[0]], tensors[op.inputs[0]]
+      output_dequant_ops.append(op)
+    # Otherwise, ignore
+    else:
+      continue
+    # If found, validate the input/output tensor type
+    if float_tensor.type != _schema_fb.TensorType.FLOAT32:
+      raise ValueError(
+          "Model {} type must be tf.float32. Expected type for tensor with "
+          "name '{}' is tf.float32, instead type is tf.{}".format(
+              pos, float_tensor.name,
+              _convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
+    if int_tensor.type != _schema_fb.TensorType.INT8:
+      raise ValueError(
+          "Model is not integer quantized. Expected type for tensor with "
+          "name '{}' is tf.int8, instead type is tf.{}".format(
+              int_tensor.name,
+              _convert_tflite_enum_type_to_tf_type(int_tensor.type).name))
+
+  return input_quant_ops, output_dequant_ops
+
+
+def modify_integer_quantized_model_io_type(
+    model, inference_input_type=_lite_constants.FLOAT,
+    inference_output_type=_lite_constants.FLOAT):
+  """Modify the float input/output type of an integer quantized model.
+
+  Args:
+    model: An int8 quantized tflite model with float input and output.
+    inference_input_type: tf.DType representing final input type.
+      (default tf.float32)
+    inference_output_type: tf.DType representing final output type.
+      (default tf.float32)
+
+  Returns:
+    An int8 quantized tflite model with modified input and/or output type.
+
+  Raises:
+    ValueError: If the model is not int8 quantized or the inference_input_type
+      and/or inference_input_type is unsupported.
+    RuntimeError: If the modification was unsuccessful.
+
+  """
+  # Return if input and output types default to float
+  if inference_input_type == _lite_constants.FLOAT and \
+      inference_output_type == _lite_constants.FLOAT:
+    return model
+
+  # Validate input and output types
+  if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
+    raise ValueError("The `inference_input_type` should be in {}".format(
+        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
+  if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
+    raise ValueError("The `inference_output_type` should be in {}".format(
+        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
+
+  logging.debug(("Attempting to modify the model input from tf.float32 to %s "
+                 "and output from tf.float32 to %s"),
+                _get_dtype_name(inference_input_type),
+                _get_dtype_name(inference_output_type))
+  # Convert the model to an object
+  model = _convert_model_from_bytearray_to_object(model)
+
+  # Validate the integer quantized model
+  input_quant_ops, output_dequant_ops = \
+      _validate_and_find_int8_quantized_inputs_outputs(model)
+
+  # Initialize references and variables
+  if len(model.subgraphs) > 1:
+    raise ValueError("Model must only have one subgraph. Instead, it has "
+                     "{} subgraphs.".format(len(model.subgraphs)))
+  subgraph = model.subgraphs[0]
+  tensors = subgraph.tensors
+  operators = subgraph.operators
+  remove_tensors_idxs = set()
+
+  # Modify model input type
+  if inference_input_type == _lite_constants.QUANTIZED_UINT8:
+    # Change quant op (float to int8) to quant op (uint8 to int8)
+    for op in input_quant_ops:
+      int8_quantization = tensors[op.outputs[0]].quantization
+      uint8_quantization = _schema_fb.QuantizationParametersT()
+      uint8_quantization.scale = [int8_quantization.scale[0]]
+      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
+      tensors[op.inputs[0]].quantization = uint8_quantization
+      tensors[op.inputs[0]].type = _schema_fb.TensorType.UINT8
+  elif inference_input_type == _lite_constants.INT8:
+    # Remove the inputs and the quant operator
+    for op in input_quant_ops:
+      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
+      remove_tensors_idxs.add(op.inputs[0])
+      operators.remove(op)
+
+  # Modify model output type
+  if inference_output_type == _lite_constants.QUANTIZED_UINT8:
+    # Change dequant op (int8 to float) to quant op (int8 to uint8)
+    for op in output_dequant_ops:
+      op.opcodeIndex = input_quant_ops[0].opcodeIndex
+      int8_quantization = tensors[op.inputs[0]].quantization
+      uint8_quantization = _schema_fb.QuantizationParametersT()
+      uint8_quantization.scale = [int8_quantization.scale[0]]
+      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
+      tensors[op.outputs[0]].quantization = uint8_quantization
+      tensors[op.outputs[0]].type = _schema_fb.TensorType.UINT8
+  elif inference_output_type == _lite_constants.INT8:
+    # Remove the outputs and the dequant operator
+    for op in output_dequant_ops:
+      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
+      remove_tensors_idxs.add(op.outputs[0])
+      operators.remove(op)
+
+  # Remove tensors marked for deletion.
+  _remove_tensors_from_model(model, remove_tensors_idxs)
+
+  # Convert the model to a bytearray
+  model = _convert_model_from_object_to_bytearray(model)
+
+  return model
diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py
index f3c287dd7fc..0e9cbc1e58a 100644
--- a/tensorflow/lite/python/util_test.py
+++ b/tensorflow/lite/python/util_test.py
@@ -19,7 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+import numpy as np
 from six.moves import range
+import tensorflow as tf
 
 from tensorflow.lite.python import lite_constants
 from tensorflow.lite.python import util
@@ -61,6 +64,31 @@ class UtilTest(test_util.TensorFlowTestCase):
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
 
+  def testConvertEnumToDtype(self):
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(0), dtypes.float32)
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(1), dtypes.float16)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(2), dtypes.int32)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(3), dtypes.uint8)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(4), dtypes.int64)
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(5), dtypes.string)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(6), dtypes.bool)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(7), dtypes.int16)
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(8), dtypes.complex64)
+    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
+    with self.assertRaises(ValueError) as error:
+      util._convert_tflite_enum_type_to_tf_type(11)
+    self.assertEqual(
+        "Unsupported enum 11. The valid map of enum to tf.dtypes is : "
+        "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
+        "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
+        "10: tf.float64}", str(error.exception))
+
   def testTensorName(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
@@ -195,5 +223,140 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
     self.assertEqual([None, 3, 5], tensor.shape.as_list())
 
 
+def _generate_integer_tflite_model():
+  """Define an integer post-training quantized tflite model."""
+  # Load MNIST dataset
+  n = 10  # Number of samples
+  (train_images, train_labels), (test_images, test_labels) = \
+      tf.keras.datasets.mnist.load_data()
+  train_images, train_labels, test_images, test_labels = \
+      train_images[:n], train_labels[:n], test_images[:n], test_labels[:n]
+
+  # Normalize the input image so that each pixel value is between 0 to 1.
+  train_images = train_images / 255.0
+  test_images = test_images / 255.0
+
+  # Define TF model
+  model = tf.keras.Sequential([
+      tf.keras.layers.InputLayer(input_shape=(28, 28)),
+      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
+      tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"),
+      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
+      tf.keras.layers.Flatten(),
+      tf.keras.layers.Dense(10)
+  ])
+
+  # Train
+  model.compile(
+      optimizer="adam",
+      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+      metrics=["accuracy"])
+
+  model.fit(
+      train_images,
+      train_labels,
+      epochs=1,
+      validation_split=0.1,
+  )
+
+  # Convert TF Model to an Integer Quantized TFLite Model
+  converter = tf.lite.TFLiteConverter.from_keras_model(model)
+  converter.optimizations = {tf.lite.Optimize.DEFAULT}
+  def representative_dataset_gen():
+    for _ in range(2):
+      yield [
+          np.random.uniform(low=0, high=1, size=(1, 28, 28)).astype(
+              np.float32)
+      ]
+  converter.representative_dataset = representative_dataset_gen
+  converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
+  tflite_model = converter.convert()
+
+  return tflite_model
+
+
+def _test_param_modify_integer_model_io_type():
+  """Function to generate parameterized inputs for testing."""
+  params = []
+  str_template = "_{}{}{}"
+  map_model_type = {
+      "PostTraining": True,
+      # "DuringTraining": False,
+  }
+  map_types = {
+      "": lite_constants.FLOAT,
+      "INT8": lite_constants.INT8,
+      "UINT8": lite_constants.QUANTIZED_UINT8
+  }
+  for k1, v1 in map_model_type.items():
+    for k2, v2 in map_types.items():
+      istr = "_Input{}".format(k2) if k2 else ""
+      for k3, v3 in map_types.items():
+        ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
+        params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
+  return params
+
+
+# TODO(b/161174063):  Merge tests for integer input/output type
+class UtilModifyIntegerQuantizedModelIOTypeTest(
+    test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
+    cls.post_train_integer_model = _generate_integer_tflite_model()
+
+  @parameterized.named_parameters(_test_param_modify_integer_model_io_type())
+  def test(self, is_post_train, in_tftype, out_tftype):
+    """Modify the float input/output type of an integer quantized model."""
+
+    def _run_tflite_inference(model, in_tftype, out_tftype):
+      """Run inference on a model with a specific input/output type."""
+      # Load TFLite model and allocate tensors.
+      interpreter = tf.lite.Interpreter(model_content=model)
+      interpreter.allocate_tensors()
+      input_details = interpreter.get_input_details()[0]
+      output_details = interpreter.get_output_details()[0]
+
+      # Validate TFLite model input and output types
+      self.assertEqual(input_details["dtype"], in_tftype.as_numpy_dtype)
+      self.assertEqual(output_details["dtype"], out_tftype.as_numpy_dtype)
+
+      # Define Input
+      np.random.seed(0)
+      input_data = np.random.uniform(low=0, high=1, size=(1, 28, 28))
+      input_data = input_data.astype(np.float32)
+      if input_details["dtype"] != np.float32:
+        # quantize float to int
+        scale, zero_point = input_details["quantization"]
+        input_data = input_data / scale + zero_point
+        input_data = input_data.astype(input_details["dtype"])
+
+      # Run Inference
+      interpreter.set_tensor(input_details["index"], input_data)
+      interpreter.invoke()
+
+      # Get output
+      output_data = interpreter.get_tensor(output_details["index"])[0]
+      if output_details["dtype"] != np.float32:
+        # dequantize int to float
+        scale, zero_point = output_details["quantization"]
+        output_data = output_data.astype(np.float32)
+        output_data = (output_data - zero_point) * scale
+
+      return output_data
+
+    model = self.__class__.post_train_integer_model if is_post_train else None
+    # Run model inference with float input output type
+    output_data = _run_tflite_inference(model, tf.float32, tf.float32)
+    # Run model inference with modified integer input output type
+    model_io = util.modify_integer_quantized_model_io_type(
+        model, in_tftype, out_tftype)
+    output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
+
+     # Validate that both the outputs are the same
+    self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))
+
+
 if __name__ == "__main__":
   test.main()