From 37c9633c508cc76ca662f9f1da7dff07d76abe44 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 11 Dec 2020 10:22:12 -0800
Subject: [PATCH] Allow state tensors to use device memories in NNAPI delegate.

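When Options::use_device_memory_for_state_tensors is enabled, the
delegate allocates NNAPI device memories for state tensors so that
recurrent state can stay on the accelerator between invocations. The
client controls when state is synced between the CPU buffers and the
device via SetSyncStatesToDevice and SetSyncStatesFromDevice.

Minimal usage sketch (it mirrors the Options comment added to
nnapi_delegate.h; interpreter construction is elided, and sequence_size
and interpreter are placeholders):

  StatefulNnApiDelegate::Options options;
  options.use_device_memory_for_state_tensors = true;
  StatefulNnApiDelegate delegate(options);
  // ... build the interpreter, ModifyGraphWithDelegate(&delegate) ...

  for (int i = 0; i < sequence_size; i++) {
    // Push the initial states to the device only before the first step.
    delegate.SetSyncStatesToDevice(i == 0);
    // Read states back into the CPU buffers only after the final step.
    delegate.SetSyncStatesFromDevice(i == sequence_size - 1);
    interpreter->Invoke();
  }
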
PiperOrigin-RevId: 347025011
Change-Id: I7441e254d020f89102b85023c5816894b3684b1f
---
 .../lite/delegates/nnapi/nnapi_delegate.cc    | 447 +++++++++++++-----
 .../lite/delegates/nnapi/nnapi_delegate.h     |  57 +++
 .../delegates/nnapi/nnapi_delegate_kernel.h   |  29 ++
 .../delegates/nnapi/nnapi_delegate_test.cc    | 143 ++++--
 4 files changed, 528 insertions(+), 148 deletions(-)

diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 89846501789..a73b44bfcbd 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -449,6 +449,78 @@ ANeuralNetworksOperandType ConvertTensorTypeToNNType(
   return nn_operand_type;
 }
 
+// Copy the CPU buffer of the input tensor to a shared memory address. Will
+// apply data type conversion if needed. The returned tensor_size is the size
+// after the potential data type conversion.
+TfLiteStatus CopyOrConvertInputData(TfLiteContext* context,
+                                    TfLiteType ann_type_equivalent,
+                                    bool use_int8_asymm_signed,
+                                    TfLiteTensor* tensor, uint8_t* dst,
+                                    int* tensor_size) {
+  if (ann_type_equivalent != kTfLiteNoType) {
+    const auto num_elements = NumElements(tensor);
+    if (tensor->type == kTfLiteUInt8 && ann_type_equivalent == kTfLiteInt32) {
+      for (int i = 0; i < num_elements; ++i) {
+        reinterpret_cast<int32_t*>(dst)[i] =
+            static_cast<const int32_t>(tensor->data.uint8[i]);
+      }
+    } else if (tensor->type == kTfLiteInt8 &&
+               ann_type_equivalent == kTfLiteUInt8) {
+      // Explicitly convert int8 values to uint8 values.
+      for (int i = 0; i < num_elements; ++i) {
+        dst[i] = static_cast<const uint8_t>(
+            static_cast<int32_t>(tensor->data.int8[i]) + 128);
+      }
+    } else if (tensor->type == kTfLiteInt8 &&
+               ann_type_equivalent == kTfLiteInt32) {
+      if (use_int8_asymm_signed) {
+        for (int i = 0; i < num_elements; ++i) {
+          reinterpret_cast<int32_t*>(dst)[i] =
+              static_cast<const int32_t>(tensor->data.int8[i]);
+        }
+      } else {
+        for (int i = 0; i < num_elements; ++i) {
+          reinterpret_cast<int32_t*>(dst)[i] =
+              static_cast<const int32_t>(tensor->data.int8[i]) + 128;
+        }
+      }
+    } else {
+      TF_LITE_KERNEL_LOG(
+          context,
+          "NN API Delegate: unsupported tensor types conversion: "
+          "from type code %d to type code %d.\n",
+          tensor->type, ann_type_equivalent);
+      return kTfLiteError;
+    }
+    size_t type_size;
+    TF_LITE_ENSURE_OK(context,
+                      GetSizeOfType(context, ann_type_equivalent, &type_size));
+    *tensor_size = NumElements(tensor) * type_size;
+  } else {
+    // copy data to pre-allocated shared memory.
+    memcpy(dst, tensor->data.raw, tensor->bytes);
+    *tensor_size = tensor->bytes;
+  }
+  return kTfLiteOk;
+}
+
+// Copy data from a shared memory address into the CPU buffer of the output
+// tensor. Will apply data type conversion if needed.
+TfLiteStatus CopyOrConvertOutputData(TfLiteType ann_type_equivalent,
+                                     const uint8_t* src, TfLiteTensor* tensor) {
+  if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
+    // Explicitly convert uint8 values to int8 values.
+    int8_t* output_ptr = tensor->data.int8;
+    const auto num_elements = NumElements(tensor);
+    for (int i = 0; i < num_elements; ++i) {
+      output_ptr[i] = static_cast<int8_t>(static_cast<int32_t>(src[i]) - 128);
+    }
+  } else {
+    memcpy(tensor->data.raw, src, tensor->bytes);
+  }
+  return kTfLiteOk;
+}
+
 constexpr size_t kDefaultByteAlignmentForNNAPI = 16;
 
 static size_t getNumPaddingBytes(size_t byte_size) {
@@ -3642,6 +3714,7 @@ TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context,
     return kTfLiteOk;
   }
 
+  const auto& delegate_data = StatefulNnApiDelegate::GetData(node->delegate);
   ANeuralNetworksCompilation* compilation = nullptr;
   if (!nnapi_devices_.empty()) {
     // Compile for the selected accelerator.
@@ -3709,6 +3782,67 @@ TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context,
   }
   RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result,
                                   "completing NNAPI compilation", nnapi_errno);
+
+  const bool use_device_memory_for_state_tensors =
+      nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13 &&
+      delegate_data.use_device_memory_for_state_tensors &&
+      delegate_data.single_partition_delegated &&
+      // State tensors with dynamic shapes are currently not supported.
+      std::all_of(model_state_tfl_inputs_.begin(),
+                  model_state_tfl_inputs_.end(), [&context](int tfl_index) {
+                    TfLiteTensor* tensor = &context->tensors[tfl_index];
+                    return !IsDynamicTensor(tensor);
+                  });
+  if (use_device_memory_for_state_tensors) {
+    for (int tfl_index : model_state_tfl_inputs_) {
+      auto& info = nn_state_tensor_info_map_.at(tfl_index);
+
+      // Prepare the device memory descriptor.
+      ANeuralNetworksMemoryDesc* desc = nullptr;
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context, nnapi_->ANeuralNetworksMemoryDesc_create(&desc),
+          "creating device memory descriptor", nnapi_errno);
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksMemoryDesc_addInputRole(
+              desc, compilation, info.nn_input_index, 1.0f),
+          "adding input role to the device memory descriptor", nnapi_errno);
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksMemoryDesc_addOutputRole(
+              desc, compilation, info.nn_output_index, 1.0f),
+          "adding output role to the device memory descriptor", nnapi_errno);
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context, nnapi_->ANeuralNetworksMemoryDesc_finish(desc),
+          "finishing device memory descriptor", nnapi_errno);
+
+      // Allocate two device memories for each state tensor.
+      ANeuralNetworksMemory* state_input_memory = nullptr;
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksMemory_createFromDesc(desc,
+                                                       &state_input_memory),
+          "creating input device memory from the descriptor", nnapi_errno);
+      info.nn_input_memory_handle.reset(state_input_memory);
+
+      ANeuralNetworksMemory* state_output_memory = nullptr;
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksMemory_createFromDesc(desc,
+                                                       &state_output_memory),
+          "creating output device memory from the descriptor", nnapi_errno);
+      info.nn_output_memory_handle.reset(state_output_memory);
+      nnapi_->ANeuralNetworksMemoryDesc_free(desc);
+
+      // We need a temporary buffer to sync states back to CPU raw pointers.
+      TfLiteTensor* tensor = &context->tensors[tfl_index];
+      if (tensor->buffer_handle == kTfLiteNullBufferHandle) {
+        info.nn_temp_buffer.reset(
+            new NNMemory(nnapi_, "temp state tensor", info.tensor_size));
+      }
+    }
+  }
+
   nn_compilation_.reset(compilation);
 
   return kTfLiteOk;
@@ -3770,6 +3904,7 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
   // Set compilation timeout if applicable.
   const auto delegate_options =
       StatefulNnApiDelegate::GetOptions(node->delegate);
+  const auto& delegate_data = StatefulNnApiDelegate::GetData(node->delegate);
   if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13) {
     if (delegate_options.max_execution_timeout_duration_ns > 0) {
       RETURN_TFLITE_ERROR_IF_NN_ERROR(
@@ -3835,14 +3970,24 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
     }
   }
 
+  const bool use_device_memory_for_state_tensors =
+      nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13 &&
+      delegate_data.use_device_memory_for_state_tensors &&
+      // TODO(b/174612931): Even if the model is not fully supported, we can
+      // still use device memories for state tensors if they are only used in
+      // one single partition.
+      delegate_data.single_partition_delegated &&
+      std::all_of(model_state_tfl_inputs_.begin(),
+                  model_state_tfl_inputs_.end(), [&context](int tfl_index) {
+                    TfLiteTensor* tensor = &context->tensors[tfl_index];
+                    return !IsDynamicTensor(tensor);
+                  });
+
   // Set the input tensor buffers. Note: we access tflite tensors using
   // absolute indices but NN api indices inputs by relative indices.
   int relative_input_index = 0;
 
-  const bool use_int8_asymm_signed =
-      target_sdk_version_ >= kMinSdkVersionForNNAPI13;
-
-  size_t input_offset = 0;
+  size_t input_offset_accumulator = 0;
   for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
     if (absolute_input_index == kTfLiteOptionalTensor) {
       continue;
@@ -3860,90 +4005,58 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
       input_nn_operand_type_ptr = &input_nn_operand_type;
     }
     if (tensor->allocation_type != kTfLiteMmapRo) {
-      if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
-          tensor->buffer_handle < tensor_memory_map_->size()) {
-        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
-            context,
-            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                execution, relative_input_index, input_nn_operand_type_ptr,
-                tensor_memory_map_->at(tensor->buffer_handle).memory, 0,
-                tensor->bytes),
-            "associating NNAPI execution input with a memory object", tensor,
-            nnapi_errno);
-        relative_input_index++;
-        continue;
-      }
-      int tensor_size = 0;
-      if (ann_type_equivalent != kTfLiteNoType) {
-        const auto num_elements = NumElements(tensor);
-        uint8_t* input_ptr = nn_input_memory_->get_data_ptr() + input_offset;
-        if (tensor->type == kTfLiteUInt8 &&
-            ann_type_equivalent == kTfLiteInt32) {
-          for (int i = 0; i < num_elements; ++i) {
-            reinterpret_cast<int32_t*>(input_ptr)[i] =
-                static_cast<const int32_t>(tensor->data.uint8[i]);
-          }
-        } else if (tensor->type == kTfLiteInt8 &&
-                   ann_type_equivalent == kTfLiteUInt8) {
-          // Explicitly convert int8 values to uint8 values.
-          for (int i = 0; i < num_elements; ++i) {
-            input_ptr[i] = static_cast<const uint8_t>(
-                static_cast<int32_t>(tensor->data.int8[i]) + 128);
-          }
-        } else if (tensor->type == kTfLiteInt8 &&
-                   ann_type_equivalent == kTfLiteInt32) {
-          if (use_int8_asymm_signed) {
-            for (int i = 0; i < num_elements; ++i) {
-              reinterpret_cast<int32_t*>(input_ptr)[i] =
-                  static_cast<const int32_t>(tensor->data.int8[i]);
-            }
-          } else {
-            for (int i = 0; i < num_elements; ++i) {
-              reinterpret_cast<int32_t*>(input_ptr)[i] =
-                  static_cast<const int32_t>(tensor->data.int8[i]) + 128;
-            }
-          }
-        } else {
-          context->ReportError(
-              context,
-              "NN API Delegate: unsupported tensor types conversion: "
-              "from type code %d to type code %d.\n",
-              tensor->type, ann_type_equivalent);
-          return kTfLiteError;
-        }
-        size_t type_size;
-        TF_LITE_ENSURE_OK(
-            context, GetSizeOfType(context, ann_type_equivalent, &type_size));
-        tensor_size = NumElements(tensor) * type_size;
-        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
-            context,
-            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                execution, relative_input_index, input_nn_operand_type_ptr,
-                nn_input_memory_->get_handle(), input_offset, tensor_size),
-            "associating NNAPI execution input with a memory object", tensor,
-            nnapi_errno);
+      ANeuralNetworksMemory* input_memory_handle = nullptr;
+      uint32_t input_offset = 0;
+      uint32_t input_length = 0;
+      const bool is_state_tensor =
+          nn_state_tensor_info_map_.count(absolute_input_index) > 0;
+      if (is_state_tensor && use_device_memory_for_state_tensors &&
+          // If the client requests to sync states to device, we will use the
+          // shared memory directly as input instead of explicitly copying into
+          // the device memory.
+          !delegate_data.sync_states_to_device) {
+        const auto& state_tensor_info =
+            nn_state_tensor_info_map_.at(absolute_input_index);
+        input_memory_handle = state_tensor_info.nn_input_memory_handle.get();
+        input_offset = 0;
+        input_length = 0;
+      } else if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
+                 tensor->buffer_handle < tensor_memory_map_->size()) {
+        input_memory_handle =
+            tensor_memory_map_->at(tensor->buffer_handle).memory;
+        input_offset = 0;
+        input_length = tensor->bytes;
       } else {
-        // copy data to pre-allocated shared memory.
-        memcpy(nn_input_memory_->get_data_ptr() + input_offset,
-               tensor->data.raw, tensor->bytes);
-        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
+        int tensor_size = 0;
+        // copy or convert tensor data to pre-allocated shared memory.
+        const bool use_int8_asymm_signed =
+            target_sdk_version_ >= kMinSdkVersionForNNAPI13;
+        TF_LITE_ENSURE_OK(
             context,
-            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
-                execution, relative_input_index, input_nn_operand_type_ptr,
-                nn_input_memory_->get_handle(), input_offset, tensor->bytes),
-            "associating NNAPI execution input with a memory object", tensor,
-            nnapi_errno);
-        tensor_size = tensor->bytes;
+            CopyOrConvertInputData(
+                context, ann_type_equivalent, use_int8_asymm_signed, tensor,
+                nn_input_memory_->get_data_ptr() + input_offset_accumulator,
+                &tensor_size));
+        input_memory_handle = nn_input_memory_->get_handle();
+        input_offset = input_offset_accumulator;
+        input_length = tensor_size;
+        input_offset_accumulator += tensor_size;
+        input_offset_accumulator += getNumPaddingBytes(tensor_size);
       }
-      input_offset += tensor_size;
-      input_offset += getNumPaddingBytes(tensor_size);
+      RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
+          context,
+          nnapi_->ANeuralNetworksExecution_setInputFromMemory(
+              execution, relative_input_index, input_nn_operand_type_ptr,
+              input_memory_handle, input_offset, input_length),
+          "associating NNAPI execution input with a memory object", tensor,
+          nnapi_errno);
       relative_input_index++;
     }
   }
 
   // Set the output tensor buffers.
   int relative_output_index = 0;
-  size_t output_offset = 0;
+  size_t output_offset_accumulator = 0;
   for (auto output_index : TfLiteIntArrayView(node->outputs)) {
     // If the NNAPI implementation doesn't have some of the outputs
     // they are left unmapped and we should not try to read their value here
@@ -3977,11 +4090,12 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
           context,
           nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
               execution, relative_output_index, output_nn_operand_type_ptr,
-              nn_output_memory_->get_handle(), output_offset, tensor->bytes),
+              nn_output_memory_->get_handle(), output_offset_accumulator,
+              tensor->bytes),
           "associating NNAPI execution output to a memory object", tensor,
           nnapi_errno);
-      output_offset += tensor->bytes;
-      output_offset += getNumPaddingBytes(tensor->bytes);
+      output_offset_accumulator += tensor->bytes;
+      output_offset_accumulator += getNumPaddingBytes(tensor->bytes);
     }
     relative_output_index++;
   }
@@ -3990,16 +4104,27 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
   // current invocation.
   for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
     int state_tensor_idx = model_state_tfl_inputs_[i];
-    TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
-    // Here we are using a deep copy for state_in tensors so that we are not
-    // reading and writing into the same buffer during a invocation.
-    // TODO(b/110369471): using double shared buffer to minimize the copies.
-    RETURN_TFLITE_ERROR_IF_NN_ERROR(
-        context,
-        nnapi_->ANeuralNetworksExecution_setOutput(
-            execution, relative_output_index, nullptr, tensor->data.raw,
-            tensor->bytes),
-        "associating NNAPI execution output to a buffer", nnapi_errno);
+    if (use_device_memory_for_state_tensors) {
+      auto* device_memory = nn_state_tensor_info_map_.at(state_tensor_idx)
+                                .nn_output_memory_handle.get();
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
+              execution, relative_output_index, nullptr, device_memory, 0, 0),
+          "associating NNAPI execution output with a device memory object",
+          nnapi_errno);
+    } else {
+      TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
+      // Here we are using a deep copy for state_in tensors so that we are not
+      // reading and writing into the same buffer during an invocation.
+      // TODO(b/110369471): using double shared buffer to minimize the copies.
+      RETURN_TFLITE_ERROR_IF_NN_ERROR(
+          context,
+          nnapi_->ANeuralNetworksExecution_setOutput(
+              execution, relative_output_index, nullptr, tensor->data.raw,
+              tensor->bytes),
+          "associating NNAPI execution output to a buffer", nnapi_errno);
+    }
     relative_output_index++;
   }
   // Invoke ANN in blocking fashion.
@@ -4022,39 +4147,70 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
   }
 
   // copy results from shared memory to the destination.
-  output_offset = 0;
+  output_offset_accumulator = 0;
   for (auto output_index : TfLiteIntArrayView(node->outputs)) {
     TfLiteTensor* tensor = &context->tensors[output_index];
     if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
       continue;
     }
-    TfLiteType ann_type_equivalent =
+    const TfLiteType ann_type_equivalent =
         operand_mapping_.lite_index_to_ann_type_conversion(output_index);
-    if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
-      // Explicitly convert uint8 values to int8 values.
-      uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
-          nn_output_memory_->get_data_ptr() + output_offset);
-      const auto num_elements = NumElements(tensor);
-      for (int i = 0; i < num_elements; ++i) {
-        output_ptr[i] =
-            static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
+    TF_LITE_ENSURE_OK(
+        context, CopyOrConvertOutputData(ann_type_equivalent,
+                                         nn_output_memory_->get_data_ptr() +
+                                             output_offset_accumulator,
+                                         tensor));
+    output_offset_accumulator += tensor->bytes;
+    output_offset_accumulator += getNumPaddingBytes(tensor->bytes);
+  }
+
+  // Sync state tensors from device memories.
+  if (use_device_memory_for_state_tensors &&
+      delegate_data.sync_states_from_device) {
+    for (auto& [tfl_index, info] : nn_state_tensor_info_map_) {
+      TfLiteTensor* tensor = &context->tensors[tfl_index];
+      if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
+          tensor->buffer_handle < tensor_memory_map_->size()) {
+        RETURN_TFLITE_ERROR_IF_NN_ERROR(
+            context,
+            nnapi_->ANeuralNetworksMemory_copy(
+                info.nn_output_memory_handle.get(),
+                tensor_memory_map_->at(tensor->buffer_handle).memory),
+            "syncing device memory from device", nnapi_errno);
+      } else {
+        // For pointer tensor data, we need to copy twice:
+        // 1. device memory -> shared memory
+        // 2. shared memory -> raw pointer
+        // The second copy may also need type conversion from uint8 -> int8.
+        RETURN_TFLITE_ERROR_IF_NN_ERROR(context,
+                                        nnapi_->ANeuralNetworksMemory_copy(
+                                            info.nn_output_memory_handle.get(),
+                                            info.nn_temp_buffer->get_handle()),
+                                        "syncing device memory from device",
+                                        nnapi_errno);
+        const TfLiteType ann_type_equivalent =
+            operand_mapping_.lite_index_to_ann_type_conversion(tfl_index);
+        TF_LITE_ENSURE_OK(context,
+                          CopyOrConvertOutputData(
+                              ann_type_equivalent,
+                              info.nn_temp_buffer->get_data_ptr(), tensor));
       }
     }
-    memcpy(tensor->data.raw, nn_output_memory_->get_data_ptr() + output_offset,
-           tensor->bytes);
-    output_offset += tensor->bytes;
-    output_offset += getNumPaddingBytes(tensor->bytes);
+  }
+
+  // Swap device memory handles so that the state output of the current
+  // invocation will be used as the state input of the next invocation.
+  if (use_device_memory_for_state_tensors) {
+    for (auto& [tfl_index, info] : nn_state_tensor_info_map_) {
+      std::swap(info.nn_input_memory_handle, info.nn_output_memory_handle);
+    }
   }
 
   // copy output of all output tensors in feedback_loops_ into the
   // associated input
-  for (auto feedback_loop : feedback_loops_) {
-    int output_tensor_idx;
-    int input_tensor_idx;
-    std::tie(output_tensor_idx, input_tensor_idx) = feedback_loop;
+  for (auto [output_tensor_idx, input_tensor_idx] : feedback_loops_) {
     TfLiteTensor& src = context->tensors[output_tensor_idx];
     TfLiteTensor& dest = context->tensors[input_tensor_idx];
-
     memcpy(dest.data.raw, src.data.raw, src.bytes);
   }
 
@@ -4622,6 +4778,17 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
   std::vector<uint32_t> outputs;
   outputs.reserve(output_tensors->size);
 
+  for (int tfl_index : model_state_tfl_inputs_) {
+    NNStateTensorInfo info = {
+        .nn_input_memory_handle =
+            std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory>(
+                nullptr, NNFreeMemory(nnapi_)),
+        .nn_output_memory_handle =
+            std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory>(
+                nullptr, NNFreeMemory(nnapi_))};
+    nn_state_tensor_info_map_.emplace(tfl_index, std::move(info));
+  }
+
   size_t total_input_byte_size = 0;
   // Make the TensorFlow Lite inputs and outputs to ann_indices.
   for (int i : TfLiteIntArrayView(input_tensors)) {
@@ -4631,10 +4798,6 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
         // The delegate might not have mapped this input (this can
         // happen if one tensor is split in several ones)
         operand_mapping_.lite_index_to_ann(i) != -1) {
-      inputs.push_back(operand_mapping_.lite_index_to_ann(i));
-      if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
-        continue;
-      }
       const TfLiteType nn_type_conversion =
           operand_mapping_.lite_index_to_ann_type_conversion(i);
       int tensor_size = 0;
@@ -4646,6 +4809,15 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
             context, GetSizeOfType(context, nn_type_conversion, &type_size));
         tensor_size = NumElements(&context->tensors[i]) * type_size;
       }
+      if (auto it = nn_state_tensor_info_map_.find(i);
+          it != nn_state_tensor_info_map_.end()) {
+        it->second.nn_input_index = inputs.size();
+        it->second.tensor_size = tensor_size;
+      }
+      inputs.push_back(operand_mapping_.lite_index_to_ann(i));
+      if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
+        continue;
+      }
       total_input_byte_size += tensor_size;
       total_input_byte_size += getNumPaddingBytes(tensor_size);
     }
@@ -4666,8 +4838,11 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
   }
 
   // Add state output tensors as model outputs.
-  for (int i : model_state_outputs_) {
-    outputs.push_back(i);
+  for (int i = 0; i < model_state_outputs_.size(); i++) {
+    const int tfl_index = model_state_tfl_inputs_[i];
+    const int nn_model_index = model_state_outputs_[i];
+    nn_state_tensor_info_map_.at(tfl_index).nn_output_index = outputs.size();
+    outputs.push_back(nn_model_index);
   }
 
   // Tell ANN to declare inputs/outputs
@@ -4772,6 +4947,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* nnapi,
   if (nnapi->android_sdk_version >= kMinSdkVersionForNNAPI11) {
     delegate_data_.allow_dynamic_dimensions = options.allow_dynamic_dimensions;
   }
+  delegate_data_.use_device_memory_for_state_tensors =
+      options.use_device_memory_for_state_tensors;
   TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                        "Created TensorFlow Lite delegate for NNAPI.");
   Prepare = DoPrepare;
@@ -4814,9 +4991,17 @@ const StatefulNnApiDelegate::Options StatefulNnApiDelegate::GetOptions(
   options.max_execution_loop_timeout_duration_ns =
       delegate_data->max_execution_loop_timeout_duration_ns;
   options.allow_dynamic_dimensions = delegate_data->allow_dynamic_dimensions;
+  options.use_device_memory_for_state_tensors =
+      delegate_data->use_device_memory_for_state_tensors;
   return options;
 }
 
+const StatefulNnApiDelegate::Data& StatefulNnApiDelegate::GetData(
+    TfLiteDelegate* delegate) {
+  auto* delegate_data = reinterpret_cast<Data*>(delegate->data_);
+  return *delegate_data;
+}
+
 const std::vector<StatefulNnApiDelegate::MemoryRegistration>&
 StatefulNnApiDelegate::GetTensorMemoryMap(TfLiteDelegate* delegate) {
   auto delegate_data = reinterpret_cast<Data*>(delegate->data_);
@@ -4877,6 +5062,24 @@ int StatefulNnApiDelegate::GetNnApiErrno() const {
   return delegate_data_.nnapi_errno;
 }
 
+TfLiteStatus StatefulNnApiDelegate::SetSyncStatesToDevice(
+    bool sync_states_to_device) {
+  if (!delegate_data_.use_device_memory_for_state_tensors) {
+    return kTfLiteError;
+  }
+  delegate_data_.sync_states_to_device = sync_states_to_device;
+  return kTfLiteOk;
+}
+
+TfLiteStatus StatefulNnApiDelegate::SetSyncStatesFromDevice(
+    bool sync_states_from_device) {
+  if (!delegate_data_.use_device_memory_for_state_tensors) {
+    return kTfLiteError;
+  }
+  delegate_data_.sync_states_from_device = sync_states_from_device;
+  return kTfLiteOk;
+}
+
 // static
 TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
     TfLiteContext* context, TfLiteDelegate* delegate, const NnApi* nnapi,
@@ -4908,9 +5111,9 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
                                    supported_partition_nodes.begin(),
                                    supported_partition_nodes.end());
 
-    bool model_fully_supported = (supported_partition_nodes.size() ==
-                                  partition_params.nodes_to_replace->size);
-    if (model_fully_supported) {
+    bool single_partition_delegated = (supported_partition_nodes.size() ==
+                                       partition_params.nodes_to_replace->size);
+    if (single_partition_delegated) {
       delegate_data->CacheDelegateKernel(&partition_params,
                                          kernel_state.release());
     }
@@ -5125,6 +5328,10 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
                                    params_array, params_array + num_partitions),
                                &nodes_to_delegate));
 
+  if (!nodes_to_delegate.empty() && num_partitions == 1) {
+    delegate_data->single_partition_delegated = true;
+  }
+
   if (nodes_to_delegate.empty()) {
     return kTfLiteOk;
   } else {
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
index 4b12b0d0d18..dbc92f7d5a4 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h
@@ -125,6 +125,33 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     // accelerator. This should only be enabled if the target device supports
     // dynamic dimensions of the model.
     bool allow_dynamic_dimensions = false;
+
+    // When set to true, the delegate will allocate device memory for state
+    // tensors to reduce data copying and transformation overhead. In such a
+    // case, the user must explicitly specify whether they would like to sync
+    // states between host and device before and after each invocation by
+    // SetSyncStatesToDevice and SetSyncStatesFromDevice. The following code
+    // snippet demonstrates the usage:
+    //
+    //   StatefulNnApiDelegate::Options options;
+    //   options.use_device_memory_for_state_tensors = true;
+    //   ...
+    //
+    //   for (int i = 0; i < sequence_size; i++) {
+    //     ...
+    //
+    //     // Push initial states to the device before the first invocation.
+    //     delegate->SetSyncStatesToDevice(i == 0);
+    //
+    //     // Get the state data back to the host CPU buffers after the final
+    //     // invocation.
+    //     delegate->SetSyncStatesFromDevice(i == sequence_size - 1);
+    //
+    //     interpreter->Invoke();
+    //   }
+    //
+    // WARNING: This is an experimental interface that is subject to change.
+    bool use_device_memory_for_state_tensors = false;
   };
 
   // Uses default options.
@@ -186,7 +213,23 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
   // (i.e. when calling interpreter.ModifyGraphWithDelegate(delegate)).
   int GetNnApiErrno() const;
 
+  // Specifies whether the device memories should be initialized from the
+  // content of the CPU buffers of state tensors before each execution.
+  // Will return an error if the delegate is not initialized with
+  // use_device_memory_for_state_tensors set to true.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteStatus SetSyncStatesToDevice(bool sync_states_to_device);
+
+  // Specifies whether the device memories should be copied back into the CPU
+  // buffers of state tensors after each execution.
+  // Will return an error if the delegate is not initialized with
+  // use_device_memory_for_state_tensors set to true.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteStatus SetSyncStatesFromDevice(bool sync_states_from_device);
+
  private:
+  friend NNAPIDelegateKernel;
+
   // Encapsulates all delegate data.
   struct Data {
     // Pointer to NNAPI implementation to be used by this delegate as
@@ -235,6 +278,17 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
     uint64_t max_execution_loop_timeout_duration_ns = 0;
     // Whether to allow dynamic dimension sizes without re-compilation.
     bool allow_dynamic_dimensions = false;
+    // When set to true, the delegate will allocate device memories for state
+    // tensors to reduce data copying and transformation overhead.
+    bool use_device_memory_for_state_tensors = false;
+    // When set to true, the device memories will be initialized from the
+    // content of CPU buffers of state tensors before the execution.
+    bool sync_states_to_device = false;
+    // When set to true, the device memories will be copied back into the CPU
+    // buffers of state tensors after the execution.
+    bool sync_states_from_device = false;
+    // Whether the delegated nodes form one single NNAPI partition.
+    bool single_partition_delegated = false;
 
     explicit Data(const NnApi* nnapi);
     ~Data();
@@ -248,6 +302,9 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
         const TfLiteDelegateParams* delegate_params);
   };
 
+  // Returns the delegate data.
+  static const Data& GetData(TfLiteDelegate* delegate);
+
   // Implements TfLiteDelegate::Prepare. Please refer to TFLiteDelegate
   // documentation for more info.
   static TfLiteStatus DoPrepare(TfLiteContext* context,
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
index 36c1dd32efb..60c32a1ef0f 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/allocation.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
 
 namespace tflite {
@@ -154,6 +155,18 @@ class NNFreeExecution {
   // NnApi instance to use. Not owned by this object.
   const NnApi* nnapi_;
 };
+// RAII NN API Memory Destructor for use with std::unique_ptr
+class NNFreeMemory {
+ public:
+  explicit NNFreeMemory(const NnApi* nnapi) : nnapi_(nnapi) {}
+  void operator()(ANeuralNetworksMemory* memory) {
+    nnapi_->ANeuralNetworksMemory_free(memory);
+  }
+
+ private:
+  // NnApi instance to use. Not owned by this object.
+  const NnApi* nnapi_;
+};
 
 // Manage NNAPI shared memory handle
 class NNMemory {
@@ -175,6 +188,19 @@ class NNMemory {
   ANeuralNetworksMemory* nn_memory_handle_ = nullptr;
 };
 
+// Basic info and NN device memory handles for state tensors.
+struct NNStateTensorInfo {
+  uint32_t nn_input_index = 0;
+  uint32_t nn_output_index = 0;
+  // The size of the NN state tensor after applying any potential data type
+  // conversion.
+  int tensor_size = 0;
+  std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory> nn_input_memory_handle;
+  std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory> nn_output_memory_handle;
+  // The shared memory used to sync the state from the device.
+  std::unique_ptr<NNMemory> nn_temp_buffer;
+};
+
 
 enum class NNAPIValidationFailureType : int {
   // The operator is not supported by either NNAPI or the NNAPI Delegate.
@@ -340,6 +366,9 @@ class NNAPIDelegateKernel {
   // data available for TFLite model users
   std::vector<std::tuple<int, int>> feedback_loops_;
 
+  // TfLite index -> state tensor info.
+  std::map<int, NNStateTensorInfo> nn_state_tensor_info_map_;
+
   std::unique_ptr<NNMemory> nn_input_memory_;
   std::unique_ptr<NNMemory> nn_output_memory_;
 
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
index 16e7a260961..c1a3923de4d 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -2718,24 +2718,15 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
  public:
   RNNOpModel(int batches, int units, int size,
              const TensorType weights = TensorType_FLOAT32,
+             const TensorType recurrent_weights = TensorType_FLOAT32) {
+    Init(batches, units, size, weights, recurrent_weights);
+  }
+
+  RNNOpModel(const StatefulNnApiDelegate::Options& options, int batches,
+             int units, int size, const TensorType weights = TensorType_FLOAT32,
              const TensorType recurrent_weights = TensorType_FLOAT32)
-      : batches_(batches), units_(units), input_size_(size) {
-    input_ = AddInput(TensorType_FLOAT32);
-    weights_ = AddInput(weights);
-    recurrent_weights_ = AddInput(recurrent_weights);
-    bias_ = AddInput(TensorType_FLOAT32);
-    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
-    output_ = AddOutput(TensorType_FLOAT32);
-    SetBuiltinOp(
-        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
-        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
-    BuildInterpreterWithNNAPI({
-        {batches_, input_size_},  // input tensor
-        {units_, input_size_},    // weights tensor
-        {units_, units_},         // recurrent weights tensor
-        {units_},                 // bias tensor
-        {batches_, units_}        // hidden state tensor
-    });
+      : SingleOpModelWithNNAPI(options) {
+    Init(batches, units, size, weights, recurrent_weights);
   }
 
   void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -2756,8 +2747,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
     PopulateTensor(input_, offset, begin, end);
   }
 
+  void SetHiddenState(const std::vector<float>& data) {
+    PopulateTensor(hidden_state_, data);
+  }
+
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
 
+  std::vector<float> GetHiddenState() {
+    return ExtractVector<float>(hidden_state_);
+  }
+
   int input_size() { return input_size_; }
   int num_units() { return units_; }
   int num_batches() { return batches_; }
@@ -2773,8 +2772,50 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
   int batches_;
   int units_;
   int input_size_;
+
+ private:
+  // Performs initialization logic shared across all constructors.
+  void Init(int batches, int units, int size, const TensorType weights,
+            const TensorType recurrent_weights) {
+    batches_ = batches;
+    units_ = units;
+    input_size_ = size;
+    input_ = AddInput(TensorType_FLOAT32);
+    weights_ = AddInput(weights);
+    recurrent_weights_ = AddInput(recurrent_weights);
+    bias_ = AddInput(TensorType_FLOAT32);
+    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(
+        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
+        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
+    BuildInterpreterWithNNAPI({
+        {batches_, input_size_},  // input tensor
+        {units_, input_size_},    // weights tensor
+        {units_, units_},         // recurrent weights tensor
+        {units_},                 // bias tensor
+        {batches_, units_}        // hidden state tensor
+    });
+  }
 };
 
+static void InvokeAndTestSingleRnnStep(int step_index, RNNOpModel* rnn) {
+  float* batch_start = rnn_input + step_index * rnn->input_size();
+  float* batch_end = batch_start + rnn->input_size();
+  rnn->SetInput(0, batch_start, batch_end);
+  rnn->SetInput(rnn->input_size(), batch_start, batch_end);
+
+  rnn->Invoke();
+
+  float* golden_start = rnn_golden_output + step_index * rnn->num_units();
+  float* golden_end = golden_start + rnn->num_units();
+  std::vector<float> expected;
+  expected.insert(expected.end(), golden_start, golden_end);
+  expected.insert(expected.end(), golden_start, golden_end);
+
+  EXPECT_THAT(rnn->GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
 TEST(NNAPIDelegate, RnnBlackBoxTest) {
   RNNOpModel rnn(2, 16, 8);
   rnn.SetWeights(rnn_weights);
@@ -2785,20 +2826,66 @@ TEST(NNAPIDelegate, RnnBlackBoxTest) {
                                   (rnn.input_size() * rnn.num_batches());
 
   for (int i = 0; i < input_sequence_size; i++) {
-    float* batch_start = rnn_input + i * rnn.input_size();
-    float* batch_end = batch_start + rnn.input_size();
-    rnn.SetInput(0, batch_start, batch_end);
-    rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+    InvokeAndTestSingleRnnStep(i, &rnn);
+  }
+}
 
-    rnn.Invoke();
+TEST(NNAPIDelegate, RnnDeviceMemoryBasicTest) {
+  StatefulNnApiDelegate::Options options;
+  options.use_device_memory_for_state_tensors = true;
 
-    float* golden_start = rnn_golden_output + i * rnn.num_units();
-    float* golden_end = golden_start + rnn.num_units();
-    std::vector<float> expected;
-    expected.insert(expected.end(), golden_start, golden_end);
-    expected.insert(expected.end(), golden_start, golden_end);
+  RNNOpModel rnn(options, 2, 16, 8);
+  rnn.SetWeights(rnn_weights);
+  rnn.SetBias(rnn_bias);
+  rnn.SetRecurrentWeights(rnn_recurrent_weights);
 
-    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+  auto* delegate = rnn.GetDelegate();
+  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+                                  (rnn.input_size() * rnn.num_batches());
+
+  // Only sync the state to the device in the first invocation; all subsequent
+  // states are kept inside the driver.
+  for (int i = 0; i < input_sequence_size; i++) {
+    delegate->SetSyncStatesToDevice(i == 0);
+    InvokeAndTestSingleRnnStep(i, &rnn);
+  }
+}
+
+TEST(NNAPIDelegate, RnnDeviceMemorySyncTest) {
+  StatefulNnApiDelegate::Options options;
+  options.use_device_memory_for_state_tensors = true;
+
+  RNNOpModel rnn(options, 2, 16, 8);
+  rnn.SetWeights(rnn_weights);
+  rnn.SetBias(rnn_bias);
+  rnn.SetRecurrentWeights(rnn_recurrent_weights);
+
+  auto* delegate = rnn.GetDelegate();
+  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+                                  (rnn.input_size() * rnn.num_batches());
+  const int sync_output_index = input_sequence_size / 2;
+
+  // The following steps test SetSyncStatesFromDevice and SetSyncStatesToDevice:
+  // 1. Invoke RNN sequence until sync_output_index;
+  // 2. Extract the hidden output state at sync_output_index by
+  //    SetSyncStatesFromDevice(true);
+  // 3. Continue RNN sequence until the end;
+  // 4. Reset the hidden state via SetHiddenState and SetSyncStatesToDevice,
+  //    rolling the state back to the value captured at sync_output_index;
+  // 5. Continue RNN sequence from sync_output_index + 1 until the end.
+  std::vector<float> hidden_state_data;
+  for (int i = 0; i < input_sequence_size; i++) {
+    delegate->SetSyncStatesToDevice(i == 0);
+    delegate->SetSyncStatesFromDevice(i == sync_output_index);
+    InvokeAndTestSingleRnnStep(i, &rnn);
+    if (i == sync_output_index) {
+      hidden_state_data = rnn.GetHiddenState();
+    }
+  }
+  rnn.SetHiddenState(hidden_state_data);
+  for (int i = sync_output_index + 1; i < input_sequence_size; i++) {
+    delegate->SetSyncStatesToDevice(i == (sync_output_index + 1));
+    InvokeAndTestSingleRnnStep(i, &rnn);
   }
 }