From 55912083e2f16087c2f29394acf8a6a4811a2ce0 Mon Sep 17 00:00:00 2001
From: Nupur Garg <nupurgarg@google.com>
Date: Fri, 31 Jan 2020 09:53:26 -0800
Subject: [PATCH] Add support for unknown dimensions to TFLite using MLIR
 converter.

PiperOrigin-RevId: 292563455
Change-Id: Ib5700cfe6faee177027329e32089abb3bcc9adaf
---
 .../mlir/lite/flatbuffer_translate.cc         | 28 ++++++-
 tensorflow/lite/c/common.c                    |  5 ++
 tensorflow/lite/c/common.h                    |  6 ++
 tensorflow/lite/c/common_test.cc              |  2 +
 tensorflow/lite/core/subgraph.cc              |  5 +-
 tensorflow/lite/core/subgraph.h               | 16 ++--
 tensorflow/lite/model.cc                      | 13 ++-
 tensorflow/lite/python/convert.py             | 12 ++-
 tensorflow/lite/python/interpreter.py         |  2 +
 .../interpreter_wrapper.cc                    | 17 ++++
 .../interpreter_wrapper/interpreter_wrapper.h |  1 +
 tensorflow/lite/python/lite.py                | 54 +++++++++----
 tensorflow/lite/python/lite_test.py           | 46 ++++++++++-
 tensorflow/lite/python/lite_v2_test.py        | 79 ++++++++++++++++++-
 tensorflow/lite/schema/schema.fbs             |  4 +
 tensorflow/lite/schema/schema_generated.h     | 28 +++++--
 .../benchmark/experimental/c/c_api_types.h    |  6 ++
 17 files changed, 284 insertions(+), 40 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 14e99ce76f8..7b909c0c857 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -610,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
   };
 
   std::vector<int32_t> shape;
+  std::vector<int32_t> shape_signature;
   if (type.hasStaticShape()) {
     llvm::ArrayRef<int64_t> shape_ref = type.getShape();
     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@@ -627,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
 
       shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
     }
+  } else if (type.hasRank()) {
+    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
+    if (mlir::failed(check_shape(shape_ref))) return llvm::None;
+
+    shape.reserve(shape_ref.size());
+    for (auto& dim : shape_ref) {
+      shape.push_back(dim == -1 ? 1 : dim);
+    }
+    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
   }
+
   Type element_type = type.getElementType();
   tflite::TensorType tflite_element_type =
       GetTFLiteType(type.getElementType()).ValueOrDie();
@@ -664,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
       break;
     }
   }
-  return tflite::CreateTensor(
-      builder_, builder_.CreateVector(shape), tflite_element_type,
-      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
-      /*is_variable=*/is_variable);
+
+  if (shape_signature.empty()) {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable);
+  } else {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable, /*sparsity=*/0,
+        /*shape_signature=*/builder_.CreateVector(shape_signature));
+  }
 }
 
 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c
index 1721e75d7ce..7196f32b62a 100644
--- a/tensorflow/lite/c/common.c
+++ b/tensorflow/lite/c/common.c
@@ -140,6 +140,11 @@ void TfLiteTensorFree(TfLiteTensor* t) {
   if (t->dims) TfLiteIntArrayFree(t->dims);
   t->dims = NULL;
 
+  if (t->dims_signature) {
+    TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+  }
+  t->dims_signature = NULL;
+
   TfLiteQuantizationFree(&t->quantization);
   TfLiteSparsityFree(t->sparsity);
   t->sparsity = NULL;
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index 4d7fe8c78a8..023e1871d2b 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
   // This is optional. The field is NULL if a tensor is dense.
   // WARNING: This is an experimental interface that is subject to change.
   TfLiteSparsity* sparsity;
+
+  // Optional. Encodes shapes with unknown dimensions with -1. This field is
+  // only populated when unknown dimensions exist in a read-write tensor (i.e.
+  // an input or output tensor). (e.g.  `dims` contains [1, 1, 1, 3] and
+  // `dims_signature` contains [1, -1, -1, 3]).
+  const TfLiteIntArray* dims_signature;
 } TfLiteTensor;
 
 #ifndef TF_LITE_STATIC_MEMORY
diff --git a/tensorflow/lite/c/common_test.cc b/tensorflow/lite/c/common_test.cc
index 7230adff0e9..0421b50c05e 100644
--- a/tensorflow/lite/c/common_test.cc
+++ b/tensorflow/lite/c/common_test.cc
@@ -95,6 +95,7 @@ TEST(Quantization, TestQuantizationFree) {
   // Set these values, otherwise TfLiteTensorFree has uninitialized values.
   t.allocation_type = kTfLiteArenaRw;
   t.dims = nullptr;
+  t.dims_signature = nullptr;
   t.quantization.type = kTfLiteAffineQuantization;
   t.sparsity = nullptr;
   auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
@@ -110,6 +111,7 @@ TEST(Sparsity, TestSparsityFree) {
   // Set these values, otherwise TfLiteTensorFree has uninitialized values.
   t.allocation_type = kTfLiteArenaRw;
   t.dims = nullptr;
+  t.dims_signature = nullptr;
 
   // A dummy CSR sparse matrix.
   t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 0f94ca0ae3c..e49b9ad9a59 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1074,7 +1074,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
 // to Interpreter.
 TfLiteStatus Subgraph::SetTensorParametersReadWrite(
     int tensor_index, TfLiteType type, const char* name, const size_t rank,
-    const int* dims, TfLiteQuantization quantization, bool is_variable) {
+    const int* dims, TfLiteQuantization quantization, bool is_variable,
+    const size_t rank_dims_signature, const int* dims_signature) {
   // Ensure quantization cleanup on failure.
   ScopedTfLiteQuantization scoped_quantization(&quantization);
   if (state_ == kStateInvokableAndImmutable) {
@@ -1114,6 +1115,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
   // TODO(suharshs): Update TfLiteTensorReset to include the new quantization
   // if there are other required callers.
   tensor.quantization = *scoped_quantization.release();
+  tensor.dims_signature =
+      ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature);
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 58c125a5f98..021439e827b 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -114,15 +114,17 @@ class Subgraph {
   inline TfLiteStatus SetTensorParametersReadWrite(
       int tensor_index, TfLiteType type, const char* name,
       const std::vector<int>& dims, TfLiteQuantization quantization,
-      bool is_variable = false) {
+      bool is_variable = false, const size_t rank_dims_signature = 0,
+      const int* dims_signature = nullptr) {
     return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
-                                        dims.data(), quantization, is_variable);
+                                        dims.data(), quantization, is_variable,
+                                        rank_dims_signature, dims_signature);
   }
-  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
-                                            const char* name, const size_t rank,
-                                            const int* dims,
-                                            TfLiteQuantization quantization,
-                                            bool is_variable = false);
+  TfLiteStatus SetTensorParametersReadWrite(
+      int tensor_index, TfLiteType type, const char* name, const size_t rank,
+      const int* dims, TfLiteQuantization quantization,
+      bool is_variable = false, const size_t rank_dims_signature = 0,
+      const int* dims_signature = nullptr);
 
   // WARNING: Experimental interface, subject to change
   // Overrides execution plan. This bounds checks indices sent in.
diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index 0556f47adba..04d064d0933 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -563,6 +563,13 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
       status = kTfLiteError;
     }
 
+    size_t dims_signature_rank = 0;
+    const int* dims_signature_data = nullptr;
+    if (tensor->shape_signature()) {
+      dims_signature_rank = tensor->shape_signature()->Length();
+      dims_signature_data = tensor->shape_signature()->data();
+    }
+
     bool is_variable = tensor->is_variable();
     if (buffer_ptr) {
       if (is_variable) {
@@ -590,9 +597,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
         status = kTfLiteError;
       }
     } else {
-      if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
-                                                 dims, quantization,
-                                                 is_variable) != kTfLiteOk) {
+      if (subgraph->SetTensorParametersReadWrite(
+              i, type, get_name(tensor), dims, quantization, is_variable,
+              dims_signature_rank, dims_signature_data) != kTfLiteOk) {
         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                 i);
         status = kTfLiteError;
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 2fe4d172487..4813edef126 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
 from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export as _tf_export
@@ -384,7 +385,16 @@ def build_toco_convert_protos(input_tensors,
       shape = input_tensor.shape
     else:
       shape = input_shapes[idx]
-    input_array.shape.dims.extend(list(map(int, shape)))
+
+    # Create shapes with -1 for unknown dimensions.
+    dims = []
+    for dim in shape:
+      if (dim is None or
+          (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
+        dims.append(-1)
+      else:
+        dims.append(int(dim))
+    input_array.shape.dims.extend(dims)
 
   for output_tensor in output_tensors:
     model.output_arrays.append(util.get_tensor_name(output_tensor))
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index 153b6f17c3c..4acedabeab9 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -320,6 +320,7 @@ class Interpreter(object):
     tensor_index = int(tensor_index)
     tensor_name = self._interpreter.TensorName(tensor_index)
     tensor_size = self._interpreter.TensorSize(tensor_index)
+    tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
     tensor_type = self._interpreter.TensorType(tensor_index)
     tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
     tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
@@ -332,6 +333,7 @@ class Interpreter(object):
         'name': tensor_name,
         'index': tensor_index,
         'shape': tensor_size,
+        'shape_signature': tensor_size_signature,
         'dtype': tensor_type,
         'quantization': tensor_quantization,
         'quantization_parameters': {
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 6b1bf34ea7d..58fb17e4f9b 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -301,6 +301,23 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
 }
 
+// Returns the shape signature of tensor `i` as a NumPy array. Per the
+// `dims_signature` field, unknown dimensions are encoded as -1. A tensor
+// without a signature yields an array built from zero elements.
+PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
+  TFLITE_PY_ENSURE_VALID_INTERPRETER();
+  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+
+  // A missing signature maps to a null data pointer with zero length.
+  const TfLiteIntArray* signature = interpreter_->tensor(i)->dims_signature;
+  const bool has_signature = signature != nullptr;
+  const int32_t* signature_data = has_signature ? signature->data : nullptr;
+  const int32_t signature_length = has_signature ? signature->size : 0;
+
+  PyObject* np_array = PyArrayFromIntVector(signature_data, signature_length);
+  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
 PyObject* InterpreterWrapper::TensorQuantization(int i) const {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();
   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
index be9086f307b..c37d3e998cd 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -69,6 +69,7 @@ class InterpreterWrapper {
   std::string TensorName(int i) const;
   PyObject* TensorType(int i) const;
   PyObject* TensorSize(int i) const;
+  PyObject* TensorSizeSignature(int i) const;
   // Deprecated in favor of TensorQuantizationScales, below.
   PyObject* TensorQuantization(int i) const;
   PyObject* TensorQuantizationParameters(int i) const;
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 657cfea1bb8..3965a4ac275 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -261,6 +261,16 @@ class TFLiteConverterBase(object):
         self.representative_dataset.input_gen, inference_input_type,
         inference_output_type, allow_float, enable_mlir_quantizer)
 
+  def _is_unknown_shapes_allowed(self):
+    """Returns True if inputs may keep unknown (None) dimensions."""
+    # TODO(b/128319310): Investigate which quantization methods work.
+    # Until then, any enabled optimization rules out unknown dimensions.
+    optimizing = self._any_optimization_enabled()
+    if optimizing:
+      return False
+    # Unknown dimensions are only allowed with the new converter.
+    return bool(self.experimental_new_converter)
+
   def _get_base_converter_args(self):
     """Returns the base converter args.
 
@@ -456,19 +466,21 @@ class TFLiteConverterV2(TFLiteConverterBase):
         config=self._grappler_config(),
         graph=frozen_func.graph)
 
-    # Checks dimensions in input tensor.
-    for tensor in input_tensors:
-      # Note that shape_list might be empty for scalar shapes.
-      shape_list = tensor.shape.as_list()
-      if None in shape_list[1:]:
-        raise ValueError(
-            "None is only supported in the 1st dimension. Tensor '{0}' has "
-            "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
-      elif shape_list and shape_list[0] is None:
-        # Set the batch size to 1 if undefined.
-        shape = tensor.shape.as_list()
-        shape[0] = 1
-        tensor.set_shape(shape)
+    if not self._is_unknown_shapes_allowed():
+      # Checks dimensions in input tensor.
+      for tensor in input_tensors:
+        # Note that shape_list might be empty for scalar shapes.
+        shape_list = tensor.shape.as_list()
+        if None in shape_list[1:]:
+          raise ValueError(
+              "None is only supported in the 1st dimension. Tensor '{0}' has "
+              "invalid shape '{1}'.".format(
+                  _get_tensor_name(tensor), shape_list))
+        elif shape_list and shape_list[0] is None:
+          # Set the batch size to 1 if undefined.
+          shape = tensor.shape.as_list()
+          shape[0] = 1
+          tensor.set_shape(shape)
 
     self._validate_quantization()
     self._validate_representative_dataset()
@@ -942,7 +954,7 @@ class TFLiteConverter(TFLiteConverterBase):
         None value for dimension in input_tensor.
     """
     # Checks dimensions in input tensor.
-    if self._has_valid_tensors():
+    if not self._is_unknown_shapes_allowed() and self._has_valid_tensors():
       for tensor in self._input_tensors:
         shape = tensor.shape
         if not shape:
@@ -1115,6 +1127,20 @@ class TFLiteConverter(TFLiteConverterBase):
         shape[0] = batch_size
         tensor.set_shape(shape)
 
+  def _is_unknown_shapes_allowed(self):
+    """Also rejects unknown shapes when `conversion_summary_dir` is set."""
+    if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
+      return False
+    if not self.conversion_summary_dir:
+      return True
+    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
+    # the MLIR converter, so warn and report them as not allowed.
+    logging.warning(
+        "`conversion_summary_dir` does not work with unknown shapes. "
+        "Graphs with unknown shapes might be different than when this flag "
+        "is disabled.")
+    return False
+
 
 @_tf_export(v1=["lite.TocoConverter"])
 class TocoConverter(object):
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index f3a29fb97a7..8c1f10af530 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -318,9 +318,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
       out_tensor = in_tensor + in_tensor
       sess = session.Session()
 
-    # Test None as shape.
+    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
+    # order to invoke shape checking code.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_new_converter = False
     with self.assertRaises(ValueError) as error:
       converter.convert()
     self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@@ -375,9 +377,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
       out_tensor = in_tensor + in_tensor
       sess = session.Session()
 
-    # Test invalid shape. None after 1st dimension.
+    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
+    # invoke shape checking code.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_new_converter = False
     with self.assertRaises(ValueError) as error:
       converter.convert()
     self.assertEqual(
@@ -385,6 +389,44 @@ class FromSessionTest(TestModels, parameterized.TestCase):
         '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
         str(error.exception))
 
+  def testSizeNone(self):
+    """Converts a graph with None after the 1st dimension, then resizes it."""
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, None, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
+
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue(([1, 1, 16, 3] == input_details[0]['shape']).all())
+    self.assertTrue(([1, -1, 16,
+                      3] == input_details[0]['shape_signature']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    # Resize by the input's tensor index (not its list position) and invoke.
+    interpreter.resize_tensor_input(input_details[0]['index'], [1, 16, 16, 3])
+    interpreter.allocate_tensors()
+    interpreter.invoke()
+
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertTrue(([1, -1, 16,
+                      3] == input_details[0]['shape_signature']).all())
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(0, output_details[0]['shape_signature'].size)
+
   def testBatchSizeValid(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 1f0156d6524..bb149399ef9 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -54,12 +54,28 @@ from tensorflow.python.training.tracking import tracking
 
 class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase):
 
-  def _evaluateTFLiteModel(self, tflite_model, input_data):
-    """Evaluates the model on the `input_data`."""
+  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
+    """Evaluates the model on the `input_data`.
+
+    Args:
+      tflite_model: TensorFlow Lite model.
+      input_data: List of EagerTensor const ops containing the input data for
+        each input tensor.
+      input_shapes: List of tuples representing the `shape_signature` and the
+        new shape of each input tensor that has unknown dimensions.
+
+    Returns:
+      [np.ndarray]
+    """
     interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    if input_shapes:
+      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
+        self.assertTrue(
+            (input_details[idx]['shape_signature'] == shape_signature).all())
+        interpreter.resize_tensor_input(idx, final_shape)
     interpreter.allocate_tensors()
 
-    input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
 
     for input_tensor, tensor_data in zip(input_details, input_data):
@@ -795,5 +811,62 @@ class GrapplerTest(TestModels):
     actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
     np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
 
+
+class UnknownShapes(TestModels):
+  """Conversion tests for models whose inputs have unknown dimensions."""
+
+  @test_util.run_v2_only
+  def testMatMul(self):
+    input_data = constant_op.constant(
+        np.array(np.random.random_sample((10, 4)), dtype=np.float32))
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32)
+    ])
+    def model(in_tensor):
+      shape = array_ops.shape_v2(in_tensor)
+      fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
+      return math_ops.matmul(fill, in_tensor)
+
+    concrete_func = model.get_concrete_function()
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Compare against TensorFlow after resizing the [-1, 4] input to [10, 4].
+    expected_value = concrete_func(input_data)
+    actual_value = self._evaluateTFLiteModel(
+        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
+    np.testing.assert_almost_equal(
+        expected_value.numpy(), actual_value[0], decimal=6)
+
+  def testBatchMatMul(self):
+    self.skipTest('BatchMatMulV2 ranked tensor check fails.')
+    input_data_1 = constant_op.constant(
+        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
+    input_data_2 = constant_op.constant(
+        np.array(np.random.random_sample((1, 2, 256)), dtype=np.float32))
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[1, 256, 256], dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=[1, None, 256], dtype=dtypes.float32)
+    ])
+    def model(in_tensor_1, in_tensor_2):
+      return math_ops.matmul(in_tensor_1, in_tensor_2)
+
+    concrete_func = model.get_concrete_function()
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # One (shape_signature, final_shape) pair per input, in input order; a
+    # fully static input has an empty signature.
+    expected_value = concrete_func(input_data_1, input_data_2)
+    actual_value = self._evaluateTFLiteModel(
+        tflite_model, [input_data_1, input_data_2],
+        input_shapes=[([], [1, 256, 256]), ([1, -1, 256], [1, 2, 256])])
+    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 0f052b5e5b3..e7d5eaed29f 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -178,6 +178,10 @@ table Tensor {
   // Parameters to encode a sparse tensor. See the example in
   // tensorflow/lite/testdata/sparse_tensor.json.
   sparsity:SparsityParameters;  // Optional.
+
+  // Encodes `shape` with unknown dimensions. Unknown dimensions are
+  // represented with -1.
+  shape_signature:[int]; // Optional.
 }
 
 // A list of builtin operators. Builtin operators are slightly faster than custom
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 5daa0782b3a..b91a2f0343d 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -3175,6 +3175,7 @@ struct TensorT : public flatbuffers::NativeTable {
   std::unique_ptr<QuantizationParametersT> quantization;
   bool is_variable;
   std::unique_ptr<SparsityParametersT> sparsity;
+  std::vector<int32_t> shape_signature;
   TensorT()
       : type(TensorType_FLOAT32),
         buffer(0),
@@ -3191,7 +3192,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_NAME = 10,
     VT_QUANTIZATION = 12,
     VT_IS_VARIABLE = 14,
-    VT_SPARSITY = 16
+    VT_SPARSITY = 16,
+    VT_SHAPE_SIGNATURE = 18
   };
   const flatbuffers::Vector<int32_t> *shape() const {
     return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
@@ -3214,6 +3216,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const SparsityParameters *sparsity() const {
     return GetPointer<const SparsityParameters *>(VT_SPARSITY);
   }
+  const flatbuffers::Vector<int32_t> *shape_signature() const {
+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_SHAPE) &&
@@ -3227,6 +3232,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
            VerifyOffset(verifier, VT_SPARSITY) &&
            verifier.VerifyTable(sparsity()) &&
+           VerifyOffset(verifier, VT_SHAPE_SIGNATURE) &&
+           verifier.VerifyVector(shape_signature()) &&
            verifier.EndTable();
   }
   TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3258,6 +3265,9 @@ struct TensorBuilder {
   void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) {
     fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
   }
+  void add_shape_signature(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature) {
+    fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature);
+  }
   explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3278,8 +3288,10 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
     flatbuffers::Offset<flatbuffers::String> name = 0,
     flatbuffers::Offset<QuantizationParameters> quantization = 0,
     bool is_variable = false,
-    flatbuffers::Offset<SparsityParameters> sparsity = 0) {
+    flatbuffers::Offset<SparsityParameters> sparsity = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) {
   TensorBuilder builder_(_fbb);
+  builder_.add_shape_signature(shape_signature);
   builder_.add_sparsity(sparsity);
   builder_.add_quantization(quantization);
   builder_.add_name(name);
@@ -3298,9 +3310,11 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
     const char *name = nullptr,
     flatbuffers::Offset<QuantizationParameters> quantization = 0,
     bool is_variable = false,
-    flatbuffers::Offset<SparsityParameters> sparsity = 0) {
+    flatbuffers::Offset<SparsityParameters> sparsity = 0,
+    const std::vector<int32_t> *shape_signature = nullptr) {
   auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
   auto name__ = name ? _fbb.CreateString(name) : 0;
+  auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0;
   return tflite::CreateTensor(
       _fbb,
       shape__,
@@ -3309,7 +3323,8 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
       name__,
       quantization,
       is_variable,
-      sparsity);
+      sparsity,
+      shape_signature__);
 }
 
 flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -10275,6 +10290,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
   { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
   { auto _e = is_variable(); _o->is_variable = _e; };
   { auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); };
+  { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } };
 }
 
 inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -10292,6 +10308,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
   auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
   auto _is_variable = _o->is_variable;
   auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0;
+  auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0;
   return tflite::CreateTensor(
       _fbb,
       _shape,
@@ -10300,7 +10317,8 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
       _name,
       _quantization,
       _is_variable,
-      _sparsity);
+      _sparsity,
+      _shape_signature);
 }
 
 inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
index 4d7fe8c78a8..023e1871d2b 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
+++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
@@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
   // This is optional. The field is NULL if a tensor is dense.
   // WARNING: This is an experimental interface that is subject to change.
   TfLiteSparsity* sparsity;
+
+  // Optional. Encodes shapes with unknown dimensions with -1. This field is
+  // only populated when unknown dimensions exist in a read-write tensor (i.e.
+  // an input or output tensor). (e.g.  `dims` contains [1, 1, 1, 3] and
+  // `dims_signature` contains [1, -1, -1, 3]).
+  const TfLiteIntArray* dims_signature;
 } TfLiteTensor;
 
 #ifndef TF_LITE_STATIC_MEMORY