From f93d0e2c7214940d7d277ec2c259e6f6c7094a88 Mon Sep 17 00:00:00 2001 From: Nat Jeffries <njeff@google.com> Date: Sun, 11 Oct 2020 12:03:02 -0700 Subject: [PATCH] Remove TFLiteTensors from FloatSVDF and IntegerSVDF methods in shared SVDF code. Call into shared SVDF code instead of TFLM specific SVDF code for float reference kernel. PiperOrigin-RevId: 336555588 Change-Id: If448e01cf0ea944229658997c5ea8ff8ec5eff2d --- .../lite/kernels/internal/reference/svdf.h | 198 +++++++----------- tensorflow/lite/kernels/svdf.cc | 42 +++- 2 files changed, 113 insertions(+), 127 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/svdf.h b/tensorflow/lite/kernels/internal/reference/svdf.h index bb986e4de0a..c61abf3adb5 100644 --- a/tensorflow/lite/kernels/internal/reference/svdf.h +++ b/tensorflow/lite/kernels/internal/reference/svdf.h @@ -36,7 +36,7 @@ namespace reference_ops { static inline void ApplyTimeWeightsBiasAndActivation( int batch_size, int memory_size, int num_filters, int num_units, int rank, - const float* const __restrict__ weights_time_ptr, + const float* const __restrict__ weights_time_data, const float* const __restrict__ bias_ptr, TfLiteFusedActivation activation, float* const __restrict__ state_ptr, float* const __restrict__ scratch_ptr, float* const __restrict__ output_ptr) { @@ -45,7 +45,7 @@ static inline void ApplyTimeWeightsBiasAndActivation( float* state_ptr_batch = state_ptr + b * memory_size * num_filters; float* scratch_ptr_batch = scratch_ptr + b * num_filters; tensor_utils::BatchVectorBatchVectorDotProduct( - weights_time_ptr, state_ptr_batch, memory_size, num_filters, + weights_time_data, state_ptr_batch, memory_size, num_filters, scratch_ptr_batch); } @@ -74,44 +74,41 @@ static inline void ApplyTimeWeightsBiasAndActivation( } inline void EvalIntegerSVDF( - TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor, - const TfLiteTensor* weights_feature_tensor, - const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor, - const TfLiteSVDFParams* params, TfLiteTensor* state_tensor, - TfLiteTensor* output_tensor, TfLiteTensor* scratch_tensor, - TfLiteTensor* output_temp_tensor, int32_t scale_1_a, int scale_1_b, - int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) { + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& weights_feature_shape, + const int8_t* weights_feature_data, const RuntimeShape& weights_time_shape, + const int16_t* weights_time_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, int16_t* state_data, + const RuntimeShape& output_shape, int8_t* output_data, + int32_t* scratch_data, int32_t* output_temp_data, int32_t scale_1_a, + int scale_1_b, int32_t scale_2_a, int scale_2_b, int32_t input_zp, + int32_t output_zp) { const int n_rank = params->rank; - const int n_batch = input_tensor->dims->data[0]; - const int n_input = input_tensor->dims->data[1]; - const int n_filter = weights_feature_tensor->dims->data[0]; + const int n_batch = input_shape.Dims(0); + const int n_input = input_shape.Dims(1); + const int n_filter = weights_feature_shape.Dims(0); const int n_unit = n_filter / n_rank; - const int n_memory = weights_time_tensor->dims->data[1]; - - int16_t* const state_ptr = GetTensorData<int16_t>(state_tensor); + const int n_memory = weights_time_shape.Dims(1); // Left shift the activation_state. // std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) 
- std::copy(state_ptr + 1, state_ptr + n_batch * n_memory * n_filter, - state_ptr); + std::copy(state_data + 1, state_data + n_batch * n_memory * n_filter, + state_data); // Feature matmul. // Note: no need to clear the latest activation, matmul is not accumulative. { - const int8_t* input = GetTensorData<int8_t>(input_tensor); - const int8_t* weight_feature = - GetTensorData<int8_t>(weights_feature_tensor); const int32_t output_max = std::numeric_limits<int16_t>::max(); const int32_t output_min = std::numeric_limits<int16_t>::min(); - int16_t* result_in_batch = state_ptr + (n_memory - 1); + int16_t* result_in_batch = state_data + (n_memory - 1); for (int b = 0; b < n_batch; b++) { - const int8_t* matrix_ptr = weight_feature; + const int8_t* matrix_data = weights_feature_data; for (int r = 0; r < n_filter; r++) { int32_t dot_prod = 0; - const int8_t* vector_in_batch = input + b * n_input; + const int8_t* vector_in_batch = input_data + b * n_input; for (int c = 0; c < n_input; c++) { - dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp); + dot_prod += *matrix_data++ * (*vector_in_batch++ - input_zp); } dot_prod = MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b); @@ -131,171 +128,134 @@ inline void EvalIntegerSVDF( // Time. { for (int b = 0; b < n_batch; ++b) { - const int16_t* state_ptr_batch = state_ptr + b * n_memory * n_filter; - int32_t* scratch_ptr_batch = - GetTensorData<int32_t>(scratch_tensor) + b * n_filter; + const int16_t* state_data_batch = state_data + b * n_memory * n_filter; + int32_t* scratch_data_batch = scratch_data + b * n_filter; tensor_utils::BatchVectorBatchVectorDotProduct( - GetTensorData<int16_t>(weights_time_tensor), state_ptr_batch, - n_memory, n_filter, scratch_ptr_batch); + weights_time_data, state_data_batch, n_memory, n_filter, + scratch_data_batch); } } // Reduce, add bias, rescale, activation. { - int32_t* output_temp = GetTensorData<int32_t>(output_temp_tensor); // Add bias. - if (bias_tensor) { - tensor_utils::VectorBatchVectorAssign(GetTensorData<int32_t>(bias_tensor), - n_unit, n_batch, output_temp); + if (bias_data) { + tensor_utils::VectorBatchVectorAssign(bias_data, n_unit, n_batch, + output_temp_data); } else { - std::fill_n(output_temp, n_batch * n_unit, 0); + std::fill_n(output_temp_data, n_batch * n_unit, 0); } // Reduce. for (int b = 0; b < n_batch; ++b) { - int32_t* output_temp_ptr = output_temp + b * n_unit; - int32_t* scratch_ptr_batch = - GetTensorData<int32_t>(scratch_tensor) + b * n_filter; - tensor_utils::ReductionSumVector(scratch_ptr_batch, output_temp_ptr, + int32_t* output_temp_ptr = output_temp_data + b * n_unit; + int32_t* scratch_data_batch = scratch_data + b * n_filter; + tensor_utils::ReductionSumVector(scratch_data_batch, output_temp_ptr, n_unit, n_rank); } // Rescale. 
const int32_t output_max = std::numeric_limits<int8_t>::max(); const int32_t output_min = std::numeric_limits<int8_t>::min(); for (int i = 0; i < n_batch * n_unit; ++i) { - int32_t x1 = output_temp[i]; + int32_t x1 = output_temp_data[i]; int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b); int32_t x3 = x2 + output_zp; int32_t x4 = std::min(std::max(output_min, x3), output_max); - GetTensorData<int8_t>(output_tensor)[i] = static_cast<int8_t>(x4); + output_data[i] = static_cast<int8_t>(x4); } } } -inline void EvalFloatSVDF(TfLiteContext* context, TfLiteNode* node, - const TfLiteTensor* input, - const TfLiteTensor* weights_feature, - const TfLiteTensor* weights_time, - const TfLiteTensor* bias, - const TfLiteSVDFParams* params, TfLiteTensor* scratch, - TfLiteTensor* state, TfLiteTensor* output) { +inline void EvalFloatSVDF( + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_feature_shape, + const float* weights_feature_data, const RuntimeShape& weights_time_shape, + const float* weights_time_data, const RuntimeShape& bias_shape, + const float* bias_data, float* scratch_data, float* state_data, + const RuntimeShape& output_shape, float* output_data) { const int rank = params->rank; - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - const int num_filters = weights_feature->dims->data[0]; + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.Dims(1); + const int num_filters = weights_feature_shape.Dims(0); const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - // Raw pointers to tensor data. - const float* input_ptr = GetTensorData<float>(input); - const float* weights_feature_ptr = GetTensorData<float>(weights_feature); - const float* weights_time_ptr = GetTensorData<float>(weights_time); - const float* bias_ptr = GetTensorData<float>(bias); - - float* state_ptr = GetTensorData<float>(state); - float* scratch_ptr = GetTensorData<float>(scratch); - - float* output_ptr = GetTensorData<float>(output); + const int memory_size = weights_time_shape.Dims(1); // Left shift the activation_state. // std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) - std::copy(state_ptr + 1, state_ptr + batch_size * memory_size * num_filters, - state_ptr); + std::copy(state_data + 1, state_data + batch_size * memory_size * num_filters, + state_data); // Clear scratch (the matmul is accumulative). - std::fill_n(scratch_ptr, batch_size * num_filters, 0.0f); + std::fill_n(scratch_data, batch_size * num_filters, 0.0f); // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_feature_ptr, num_filters, input_size, input_ptr, batch_size, - scratch_ptr); + weights_feature_data, num_filters, input_size, input_data, batch_size, + scratch_data); // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. 
for (int i = 0; i < batch_size * num_filters; ++i) { - state_ptr[i * memory_size + memory_size - 1] = scratch_ptr[i]; + state_data[i * memory_size + memory_size - 1] = scratch_data[i]; } ApplyTimeWeightsBiasAndActivation( - batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr, - bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr); + batch_size, memory_size, num_filters, num_units, rank, weights_time_data, + bias_data, params->activation, state_data, scratch_data, output_data); } inline void EvalHybridSVDF( - TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, - const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, - const TfLiteTensor* bias, const TfLiteSVDFParams* params, - TfLiteTensor* scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output, - TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) { + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_feature_shape, + const int8_t* weights_feature_data, const float weights_feature_scale, + const RuntimeShape& weights_time_shape, const float* weights_time_data, + const RuntimeShape& bias_shape, const float* bias_data, float* scratch, + float* scaling_factors, int8_t* quantized_input, float* state, + const RuntimeShape& output_shape, float* output_data, int32_t* zero_points, + int32_t* row_sums, bool* compute_row_sums) { const int rank = params->rank; - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - const int num_filters = weights_feature->dims->data[0]; + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.Dims(1); + const int num_filters = weights_feature_shape.Dims(0); const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - // Raw pointers to tensor data. - const float* input_ptr = GetTensorData<float>(input); - const int8_t* weights_feature_ptr = GetTensorData<int8_t>(weights_feature); - const float* weights_time_ptr = GetTensorData<float>(weights_time); - const float* bias_ptr = GetTensorData<float>(bias); - - int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized); - float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); - float* state_ptr = GetTensorData<float>(state); - float* scratch_ptr = GetTensorData<float>(scratch); - - float* output_ptr = GetTensorData<float>(output); - - int32_t* zero_points_ptr = nullptr; - int32_t* row_sums_ptr = nullptr; - if (params->asymmetric_quantize_inputs && row_sums != nullptr) { - zero_points_ptr = GetTensorData<int32_t>(zero_points); - row_sums_ptr = GetTensorData<int32_t>(row_sums); - } - - // Initialize the weights scale. - const float weights_feature_scale = weights_feature->params.scale; + const int memory_size = weights_time_shape.Dims(1); // Left shift the activation_state. // std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) - std::copy(state_ptr + 1, state_ptr + batch_size * memory_size * num_filters, - state_ptr); + std::copy(state + 1, state + batch_size * memory_size * num_filters, state); // Clear scratch (the matmul is accumulative). 
- std::fill_n(scratch_ptr, batch_size * num_filters, 0.0f); + std::fill_n(scratch, batch_size * num_filters, 0.0f); - if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { + if (!tensor_utils::IsZeroVector(input_data, batch_size * input_size)) { // Quantize input from float to int8_t. - tensor_utils::BatchQuantizeFloats(input_ptr, batch_size, input_size, - quantized_input_ptr, scaling_factors_ptr, - zero_points_ptr, - params->asymmetric_quantize_inputs); + tensor_utils::BatchQuantizeFloats( + input_data, batch_size, input_size, quantized_input, scaling_factors, + zero_points, params->asymmetric_quantize_inputs); for (int b = 0; b < batch_size; ++b) { - scaling_factors_ptr[b] *= weights_feature_scale; + scaling_factors[b] *= weights_feature_scale; } // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_feature_ptr, num_filters, input_size, quantized_input_ptr, - scaling_factors_ptr, batch_size, scratch_ptr, - /*per_channel_scale=*/nullptr, zero_points_ptr, - reinterpret_cast<int32_t*>(scratch_ptr), row_sums_ptr, compute_row_sums, + weights_feature_data, num_filters, input_size, quantized_input, + scaling_factors, batch_size, scratch, + /*per_channel_scale=*/nullptr, zero_points, + reinterpret_cast<int32_t*>(scratch), row_sums, compute_row_sums, /*context=*/nullptr); } // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. for (int i = 0; i < batch_size * num_filters; ++i) { - state_ptr[i * memory_size + memory_size - 1] = scratch_ptr[i]; + state[i * memory_size + memory_size - 1] = scratch[i]; } // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying // time weights so that the inner loop multiplies eight elements at a time. 
ApplyTimeWeightsBiasAndActivation( - batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr, - bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr); + batch_size, memory_size, num_filters, num_units, rank, weights_time_data, + bias_data, params->activation, state, scratch, output_data); } } // namespace reference_ops diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc index f3a17e1d3b5..8f5c9a86bff 100644 --- a/tensorflow/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -304,9 +304,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (weights_feature->type) { case kTfLiteFloat32: { - reference_ops::EvalFloatSVDF(context, node, input, weights_feature, - weights_time, bias, params, scratch, state, - output); + reference_ops::EvalFloatSVDF( + params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(weights_feature), + GetTensorData<float>(weights_feature), GetTensorShape(weights_time), + GetTensorData<float>(weights_time), GetTensorShape(bias), + GetTensorData<float>(bias), GetTensorData<float>(scratch), + GetTensorData<float>(state), GetTensorShape(output), + GetTensorData<float>(output)); return kTfLiteOk; break; } @@ -346,10 +351,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_data->float_weights_time_initialized = true; } + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs && row_sums != nullptr) { + zero_points_ptr = GetTensorData<int32_t>(zero_points); + row_sums_ptr = GetTensorData<int32_t>(row_sums); + } + reference_ops::EvalHybridSVDF( - context, node, input, weights_feature, float_weights_time, bias, - params, scratch, scaling_factors, input_quantized, state, output, - zero_points, row_sums, &op_data->compute_row_sums); + params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(weights_feature), + GetTensorData<int8_t>(weights_feature), + weights_feature->params.scale, GetTensorShape(float_weights_time), + GetTensorData<float>(float_weights_time), GetTensorShape(bias), + GetTensorData<float>(bias), GetTensorData<float>(scratch), + GetTensorData<float>(scaling_factors), + GetTensorData<int8_t>(input_quantized), GetTensorData<float>(state), + GetTensorShape(output), GetTensorData<float>(output), + zero_points_ptr, row_sums_ptr, &op_data->compute_row_sums); return kTfLiteOk; } else { auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>( @@ -363,9 +382,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently supports only ReLU. // TODO(jianlijianli): support other activations. TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); + reference_ops::EvalIntegerSVDF( - context, node, input, weights_feature, weights_time, bias, params, - state, output, scratch, output_temp, op_data->effective_scale_1_a, + params, GetTensorShape(input), GetTensorData<int8_t>(input), + GetTensorShape(weights_feature), + GetTensorData<int8_t>(weights_feature), + GetTensorShape(weights_time), GetTensorData<int16_t>(weights_time), + GetTensorShape(bias), GetTensorData<int32_t>(bias), + GetTensorData<int16_t>(state), GetTensorShape(output), + GetTensorData<int8_t>(output), GetTensorData<int32_t>(scratch), + GetTensorData<int32_t>(output_temp), op_data->effective_scale_1_a, op_data->effective_scale_1_b, op_data->effective_scale_2_a, op_data->effective_scale_2_b, input_params->zero_point->data[0], output_params->zero_point->data[0]);
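Usage sketch (illustrative only, not part of the patch): after this change, EvalFloatSVDF takes RuntimeShape descriptors and raw float buffers instead of TfLiteTensor*, so the shared reference kernel can be called without a TfLiteContext, TfLiteNode, or any TfLiteTensor objects. The helper name RunFloatSvdfOnRawBuffers, the dimensions, and the buffer contents below are assumptions made up for illustration, mirroring the arguments svdf.cc now builds with GetTensorShape() and GetTensorData<float>().

#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/svdf.h"
#include "tensorflow/lite/kernels/internal/types.h"

// Hypothetical caller: drives the shared float SVDF reference kernel from
// plain buffers, with no TfLiteTensor objects involved.
void RunFloatSvdfOnRawBuffers() {
  constexpr int kBatch = 1, kInputSize = 3, kRank = 1;
  constexpr int kNumFilters = 4, kMemorySize = 10;
  constexpr int kNumUnits = kNumFilters / kRank;

  TfLiteSVDFParams params = {};
  params.rank = kRank;
  params.activation = kTfLiteActRelu;

  // Shapes correspond to what svdf.cc derives with GetTensorShape().
  const tflite::RuntimeShape input_shape({kBatch, kInputSize});
  const tflite::RuntimeShape weights_feature_shape({kNumFilters, kInputSize});
  const tflite::RuntimeShape weights_time_shape({kNumFilters, kMemorySize});
  const tflite::RuntimeShape bias_shape({kNumUnits});
  const tflite::RuntimeShape output_shape({kBatch, kNumUnits});

  // Illustrative data; a real caller would point at its own buffers.
  std::vector<float> input(kBatch * kInputSize, 0.5f);
  std::vector<float> weights_feature(kNumFilters * kInputSize, 0.1f);
  std::vector<float> weights_time(kNumFilters * kMemorySize, 0.1f);
  std::vector<float> bias(kNumUnits, 0.0f);
  std::vector<float> state(kBatch * kMemorySize * kNumFilters, 0.0f);
  std::vector<float> scratch(kBatch * kNumFilters, 0.0f);
  std::vector<float> output(kBatch * kNumUnits, 0.0f);

  // Argument order follows the new signature: params, shapes and data for
  // input / weights_feature / weights_time / bias, then scratch, state,
  // and finally the output shape and buffer.
  tflite::reference_ops::EvalFloatSVDF(
      &params, input_shape, input.data(), weights_feature_shape,
      weights_feature.data(), weights_time_shape, weights_time.data(),
      bias_shape, bias.data(), scratch.data(), state.data(), output_shape,
      output.data());
}

The integer and hybrid paths in svdf.cc follow the same pattern: shapes come from GetTensorShape(), buffers from GetTensorData<T>(), and quantization details (effective scales, zero points, weights_feature->params.scale, row sums) are passed as explicit arguments, so the shared reference code no longer depends on TfLiteContext, TfLiteNode, or TfLiteTensor.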