From c37b2f11529182853bbf234232b4da4b5ff476d6 Mon Sep 17 00:00:00 2001
From: Fergus Henderson <fergus@google.com>
Date: Fri, 4 Dec 2020 04:07:24 -0800
Subject: [PATCH] Fix some old TODO comments: delete methods of Subgraph that
 expose mutable access to the private `tensors_` vector and to the private
 `nodes_and_registrations_` vector, replacing them with functions that provide
 mutable access to the individual tensors/nodes/registrations, and apply minor
 refactorings to the code that was previously using these methods to use other
 methods instead.

PiperOrigin-RevId: 345649125
Change-Id: I1b38c22421d62ef85c7a1eadf9662234995c90fe
---
 RELEASE.md                                    | 10 +++
 tensorflow/lite/core/subgraph.cc              | 17 ++---
 tensorflow/lite/core/subgraph.h               | 11 ----
 .../gpu/common/model_builder_test.cc          | 63 ++++++++++---------
 tensorflow/lite/interpreter.cc                |  6 +-
 5 files changed, 55 insertions(+), 52 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index 90d93640fe9..6340db05345 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -42,6 +42,16 @@
         renamed `tf.function(jit_compile=True)`.
 
 *   `tf.lite`:
+    *   class `tflite::Subgraph`:
+        *   Removed the `tensors()` method and the non-const overload of the
+            `nodes_and_registration()` method, both of which were previously
+            documented as temporary and to be removed.
+            *   Uses of `tensors()` can be replaced by calling the existing
+                methods `tensors_size()` and `tensor(int)`.
+            *   Uses of the non-const overload of `nodes_and_registration`
+                can be replaced by calling the existing methods `nodes_size()`
+                and `context()`, and then calling the `GetNodeAndRegistration`
+                method in the `TfLiteContext` returned by `context()`.
     *   NNAPI
         *   Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
             *   Use `NnApiDelegate()` and related delegate configuration methods
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 1deda07d397..14e2a02c605 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -162,9 +162,9 @@ class InterpreterInfo : public GraphInfo {
  public:
   explicit InterpreterInfo(Subgraph* subgraph) : subgraph_(subgraph) {}
 
-  size_t num_tensors() const override { return subgraph_->tensors().size(); }
+  size_t num_tensors() const override { return subgraph_->tensors_size(); }
   TfLiteTensor* tensor(size_t index) override {
-    return &subgraph_->tensors()[index];
+    return subgraph_->tensor(index);
   }
   size_t num_execution_nodes() const override {
     return subgraph_->execution_plan().size();
@@ -218,7 +218,7 @@ Subgraph::Subgraph(ErrorReporter* error_reporter,
 
   // Reserve some space for the tensors to avoid excessive resizing.
   tensors_.reserve(kTensorsReservedCapacity);
-  nodes_and_registration().reserve(kTensorsReservedCapacity);
+  nodes_and_registration_.reserve(kTensorsReservedCapacity);
   // Invalid to call these except from TfLiteDelegate
   SwitchToKernelContext();
 }
@@ -961,12 +961,13 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
   // overhead should be minimal since the number of custom-allocated tensors
   // will typically be low.
   for (int i = 0; i < custom_allocations_.size(); ++i) {
-    auto idx_and_alloc = custom_allocations_[i];
-    auto& tensor = tensors()[idx_and_alloc.first];
-    const auto& alloc = idx_and_alloc.second;
-    TF_LITE_ENSURE(context(), tensor.allocation_type == kTfLiteCustom);
+    auto index_and_alloc = custom_allocations_[i];
+    TfLiteTensor* tensor_at_index = tensor(index_and_alloc.first);
+    const auto& alloc = index_and_alloc.second;
+    TF_LITE_ENSURE(context(),
+                   tensor_at_index->allocation_type == kTfLiteCustom);
     TF_LITE_ENSURE_STATUS(
-        ValidateCustomAllocationForTensor(context(), &tensor, alloc));
+        ValidateCustomAllocationForTensor(context(), tensor_at_index, alloc));
   }
 
   next_execution_plan_index_to_plan_allocation_ =
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 1d295e2bc08..d4feaa91ced 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -180,17 +180,6 @@ class Subgraph {
   // Return read-only vector of node indices in the order of execution.
   const std::vector<int>& execution_plan() const { return execution_plan_; }
 
-  // Mutable form of tensors (TEMPORARY for refactor).
-  // TODO(b/119495520): remove when refactoring complete.
-  std::vector<TfLiteTensor>& tensors() { return tensors_; }
-
-  // Mutable form of nodes_and_registration (TEMPORARY for refactor).
-  // TODO(b/119495520): remove when refactoring complete.
-  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
-  nodes_and_registration() {
-    return nodes_and_registration_;
-  }
-
   const std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
   nodes_and_registration() const {
     return nodes_and_registration_;
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
index 9bc848b9210..5bf0d60a6aa 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
@@ -141,12 +141,28 @@ class DelegatedInterpreter {
   }
 
   // Get the TfLiteContext to be mocked for swapping out functions that have to
-  // be called inside delegate (i.e. in delegat kernel mode).
+  // be called inside delegate (i.e. in delegate kernel mode).
   TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }
 
-  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
-  nodes_and_registration() {
-    return interpreter_.primary_subgraph().nodes_and_registration();
+  // node(int) and registration(int) are used to implement
+  // GetNodeAndRegistration.  We can't implement those using
+  //   TfLiteContext *context = interpreter_.primary_subgraph().context();
+  //   context->GetNodeAndRegistration(context, &node, &registration);
+  // here, because calling GetNodeAndRegistration from within it's own
+  // implementation would lead to an infinite loop.
+  // Instead, we just call node_and_registration and use a const_cast.
+  // These const_casts are a bit ugly, but I think less ugly than exposing
+  // the private GetNodeAndRegistration method in Subgraph as public,
+  // or making this class a friend of Subgraph.
+  TfLiteNode* node(int index) {
+    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
+        interpreter_.primary_subgraph().node_and_registration(index);
+    return const_cast<TfLiteNode*>(&node_and_registration->first);
+  }
+  TfLiteRegistration* registration(int index) {
+    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
+        interpreter_.primary_subgraph().node_and_registration(index);
+    return const_cast<TfLiteRegistration*>(&node_and_registration->second);
   }
 
   TfLiteIntArray* exec_plan() const { return exec_plan_; }
@@ -280,10 +296,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_fp16_add_op->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_fp16_add_op->node(node_index);
+    *registration = interpreter_fp16_add_op->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
@@ -346,10 +360,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_fp16_gt_op->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_fp16_gt_op->node(node_index);
+    *registration = interpreter_fp16_gt_op->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
@@ -462,9 +474,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_fp32->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_fp32->node(node_index);
+    *registration = interpreter_fp32->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
@@ -630,10 +641,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter2_fp32->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter2_fp32->node(node_index);
+    *registration = interpreter2_fp32->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
@@ -845,9 +854,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_mn->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_mn->node(node_index);
+    *registration = interpreter_mn->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
@@ -917,9 +925,8 @@ TEST(ModelBuilderTest,
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_mn2->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_mn2->node(node_index);
+    *registration = interpreter_mn2->registration(node_index);
     return kTfLiteOk;
   };
 
@@ -1121,10 +1128,8 @@ TEST(ModelBuilderTest, GetOpsToReplace_AllowQuantOps) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_quant->nodes_and_registration()[node_index];
-    *node = &node_and_reg.first;
-    *registration = &node_and_reg.second;
+    *node = interpreter_quant->node(node_index);
+    *registration = interpreter_quant->registration(node_index);
     return kTfLiteOk;
   };
   context->PreviewDelegatePartitioning =
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index f97d6d805d0..b5ef31edc4d 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -411,8 +411,7 @@ TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                           TfLiteBufferHandle buffer_handle,
                                           TfLiteDelegate* delegate) {
   TF_LITE_ENSURE(context_, tensor_index < tensors_size());
-  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
-  TfLiteTensor* tensor = &tensors[tensor_index];
+  TfLiteTensor* tensor = primary_subgraph().tensor(tensor_index);
 
   TF_LITE_ENSURE(context_,
                  tensor->delegate == nullptr || tensor->delegate == delegate);
@@ -431,8 +430,7 @@ TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                           TfLiteBufferHandle* buffer_handle,
                                           TfLiteDelegate** delegate) {
   TF_LITE_ENSURE(context_, tensor_index < tensors_size());
-  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
-  TfLiteTensor* tensor = &tensors[tensor_index];
+  TfLiteTensor* tensor = primary_subgraph().tensor(tensor_index);
 
   *delegate = tensor->delegate;
   *buffer_handle = tensor->buffer_handle;