From be369f57e9e46d03ccd62f1031f9dc484c1016de Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Mon, 27 Jan 2020 15:31:55 -0800 Subject: [PATCH] Dispatch to CpuBackendGemm for Hybrid Ops. PiperOrigin-RevId: 291814226 Change-Id: Ie7671d134082cf88e4c75f3c7da69bc2c4ae9fe0 --- .../kernels/bidirectional_sequence_lstm.cc | 36 +- tensorflow/lite/kernels/conv.cc | 37 ++- tensorflow/lite/kernels/fully_connected.cc | 31 +- tensorflow/lite/kernels/internal/BUILD | 1 + .../internal/optimized/legacy_optimized_ops.h | 10 +- .../internal/optimized/neon_tensor_utils.cc | 46 ++- .../internal/optimized/neon_tensor_utils.h | 11 + .../optimized/neon_tensor_utils_impl.h | 9 + .../internal/optimized/optimized_ops.h | 14 +- .../internal/optimized/sse_tensor_utils.h | 10 + .../reference/portable_tensor_utils.h | 10 + .../lite/kernels/internal/tensor_utils.h | 9 + .../kernels/internal/tensor_utils_test.cc | 15 + tensorflow/lite/kernels/lstm.cc | 26 +- tensorflow/lite/kernels/lstm_eval.cc | 310 ++++++++++-------- tensorflow/lite/kernels/lstm_eval.h | 3 +- .../kernels/unidirectional_sequence_lstm.cc | 25 +- 17 files changed, 435 insertions(+), 168 deletions(-) diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 4b2b582877b..33c43aacbc7 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -137,8 +138,9 @@ enum TemporaryTensor { kScalingFactors = 7, kProductScalingFactors = 8, kRecoveredCellWeights = 9, - kAuxInputQuantized = 10, // Optional, quantized tensor for auxiliary input. - kNumTemporaryTensors = 11 + kAccumScratchBuffer = 10, + kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input. + kNumTemporaryTensors }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -726,6 +728,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { recovered_cell_weights_size)); } + // Allocate a temporary tensor to store the accumulated int32 values. + node->temporaries->data[kAccumScratchBuffer] = + *scratch_tensor_index + kAccumScratchBuffer; + TfLiteTensor* accum_scratch = + GetTemporary(context, node, kAccumScratchBuffer); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int n_cell = std::max(n_fw_cell, n_bw_cell); + if (has_aux_input) { + n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]); + n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]); + } + int accum_scratch_dims[2] = {n_cell, n_batch}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); + accum_size->data[0] = n_cell; + accum_size->data[1] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, accum_scratch, accum_size)); + } + // Only allocate a temporary tensor for quantized auxiliary input if we are // actually going to use it. if (has_aux_input) { @@ -977,6 +1001,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* aux_input_quantized = use_aux_input ? 
GetTemporary(context, node, kAuxInputQuantized) : nullptr; + TfLiteTensor* accum_scratch = + GetTemporary(context, node, kAccumScratchBuffer); TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, @@ -998,7 +1024,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, aux_input_quantized, fw_activation_state_quantized, fw_cell_state_quantized, - fw_activation_state, fw_cell_state, fw_output); + fw_activation_state, fw_cell_state, accum_scratch, fw_output, + CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid( @@ -1021,7 +1048,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, aux_input_quantized, bw_activation_state_quantized, bw_cell_state_quantized, - bw_activation_state, bw_cell_state, actual_bw_output); + bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, + CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index fc2541e93e0..a9724f5ec9d 100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -71,6 +71,7 @@ struct OpData { int input_quantized_id = kTensorNotAllocated; int scaling_factors_id = kTensorNotAllocated; int input_offset_id = kTensorNotAllocated; + int accum_scratch_id = kTensorNotAllocated; TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can @@ -92,6 +93,7 @@ struct OpData { int32_t hwcn_weights_index; int32_t input_quantized_index; int32_t scaling_factors_index; + int32_t accum_scratch_index; int32_t input_offset_index; bool need_hwcn_weights = false; @@ -262,6 +264,13 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, } ++temporaries_count; + // Allocate tensor to store the accumulators for the matrix multiply. 
+ data->accum_scratch_index = temporaries_count; + if (data->accum_scratch_id == kTensorNotAllocated) { + TF_LITE_ENSURE_OK( + context, context->AddTensors(context, 1, &data->accum_scratch_id)); + } + ++temporaries_count; if (is_per_channel) { data->input_offset_index = temporaries_count; if (data->input_offset_id == kTensorNotAllocated) { @@ -485,6 +494,21 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context, scaling_factors_size)); } + node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id; + TfLiteTensor* accum_scratch = + GetTemporary(context, node, data->accum_scratch_index); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {channels_out, batches}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = channels_out; + accum_scratch_size->data[1] = batches; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + if (is_hybrid_per_channel) { const auto* affine_quantization = reinterpret_cast( @@ -771,7 +795,7 @@ template void EvalHybrid(TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, OpData* data, TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, - TfLiteTensor* output) { + TfLiteTensor* accum_scratch, TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRange(params->activation, &output_activation_min, &output_activation_max); @@ -816,9 +840,11 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node, op_params, scaling_factors_ptr, GetTensorShape(input), quantized_input_ptr_batch, GetTensorShape(filter), GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), + GetTensorData(bias), GetTensorShape(accum_scratch), + GetTensorData(accum_scratch), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), - GetTensorData(im2col)); + GetTensorData(im2col), + CpuBackendContext::GetFromContext(context)); break; } } @@ -859,8 +885,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) { EvalHybridPerChannel(context, node, params, data, input, filter, bias, im2col, output); } else { + TfLiteTensor* accum_scratch = + &context->tensors[node->temporaries + ->data[data->accum_scratch_index]]; EvalHybrid(context, node, params, data, input, filter, - bias, im2col, output); + bias, im2col, accum_scratch, output); } } else { EvalFloat(context, node, params, data, input, filter, bias, diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 733001c1e33..36dab796a28 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -116,7 +116,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to carry information from Prepare() to // Eval(). auto* op_data = new OpData(); - context->AddTensors(context, /*tensors_to_add=*/2, + context->AddTensors(context, /*tensors_to_add=*/3, &op_data->scratch_tensor_index); return op_data; } @@ -183,10 +183,12 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { // If we have to perform on-the-fly quantization (with quantized weights and // float inputs) first we need to quantize the inputs. 
Allocate a temporary // buffer to store the intermediate quantized values. + // Additionally, we allocate a temporary buffer to store the accumulated + // quantized values prior to multiplication by the scaling factor. if (input->type == kTfLiteFloat32 && (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) { TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); @@ -201,6 +203,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); scaling_factors->type = kTfLiteFloat32; scaling_factors->allocation_type = kTfLiteArenaRw; + int scaling_dims[1] = {batch_size}; if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); @@ -208,6 +211,20 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } + + node->temporaries->data[2] = data->scratch_tensor_index + 2; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {num_units, batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); + accum_size->data[0] = num_units; + accum_size->data[1] = batch_size; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, accum_scratch, accum_size)); + } } // Resize output. @@ -341,11 +358,19 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, } // Compute output += weight * quantized_input +#ifdef TFLITE_WITH_RUY_GEMV + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); + int32_t* scratch = GetTensorData(accum_scratch); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter_data, num_units, input_size, quant_data, scaling_factors_ptr, + batch_size, scratch, GetTensorData(output), + /*result_stride=*/1, CpuBackendContext::GetFromContext(context)); +#else tensor_utils::MatrixBatchVectorMultiplyAccumulate( filter_data, num_units, input_size, quant_data, scaling_factors_ptr, batch_size, GetTensorData(output), /*result_stride=*/1); - +#endif // Apply activation function to floats. tensor_utils::ApplyActivationToVector( GetTensorData(output), batch_size * num_units, params->activation, diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 0a571317809..737a8c1df1b 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -639,6 +639,7 @@ cc_library( ":neon_tensor_utils", ":portable_tensor_utils", "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:cpu_backend_context", "//tensorflow/lite/kernels:op_macros", ], ) diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index b389f493413..da612804253 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "public/gemmlowp.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h" #include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h" @@ -2551,8 +2552,10 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, float* scaling_factors_ptr, float output_activation_min, float output_activation_max, + int32_t* scratch_data, const Dims<4>& scratch_dims, float* output_data, const Dims<4>& output_dims, - int8_t* im2col_data, const Dims<4>& im2col_dims) { + int8_t* im2col_data, const Dims<4>& im2col_dims, + CpuBackendContext* context) { tflite::ConvParams op_params; // Padding type is ignored, but still set. op_params.padding_type = PaddingType::kSame; @@ -2565,8 +2568,9 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims), input_data, DimsToShape(filter_dims), filter_data, - DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), - output_data, DimsToShape(im2col_dims), im2col_data); + DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims), + scratch_data, DimsToShape(output_dims), output_data, + DimsToShape(im2col_dims), im2col_data, context); } template diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 1518d95826f..eefe4bab0ca 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -974,7 +974,9 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias, dst_params.cols = n_batch; GemmParams gemm_params; - gemm_params.bias = bias; + if (bias) { + gemm_params.bias = bias; + } cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input, dst_params, scratch, gemm_params, context); } @@ -1145,6 +1147,48 @@ void NeonMatrixBatchVectorMultiplyAccumulate( free(aligned_vec_free); } +void NeonMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context) { + if (m_rows % 4 == 0 && result_stride == 1) { + const int32_t* bias = static_cast(nullptr); + NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, + /*output_zp =*/0, scratch, context); + + // Multiply by float scaling factors and write to result + const int total_size = n_batch * m_rows; + int i = 0; + for (; i <= total_size - 8; i += 8, result += 8 * result_stride) { + const float batch_scaling_factor0 = scaling_factors[i / m_rows]; + const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows]; + const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0); + const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1); + const int32x4_t scratch_val0 = vld1q_s32(scratch + i); + const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4); + const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0); + const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1); + const float32x4_t result0 = + vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0); + const float32x4_t result1 = vmlaq_f32( + 
vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1); + vst1q_f32(result, result0); + vst1q_f32(result + 4 * result_stride, result1); + } + scratch += i; + for (; i < total_size; i++, result += result_stride) { + const float batch_scaling_factor = scaling_factors[i / m_rows]; + int32_t x = *(scratch++); + *result += x * batch_scaling_factor; + } + return; + } + NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, result, + result_stride); +} + void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, int32_t n_row, int32_t n_col, int32_t* output) { diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index 571d3ff108f..996d9b1b2bd 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -18,6 +18,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h" #include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h" @@ -41,6 +42,16 @@ void MatrixBatchVectorMultiplyAccumulate( vectors, scaling_factors, n_batch, result, result_stride); } +void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context) { + NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vectors, scaling_factors, n_batch, scratch, result, + result_stride, context); +} + void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h index 8e604d9b33e..bfeb8e628a9 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h @@ -18,6 +18,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. 
#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" #if defined(_MSC_VER) @@ -42,6 +43,14 @@ void NeonMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride); +// Same as above but with a scratch buffer and CpuBackendContext for the +// int8 x int8 -> int32 accumulation computation +void NeonMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context); + // Matrix multiplication for quantized values using asymmetric quantization. void NeonMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 537cfee1e31..331d89d7fb9 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -1246,8 +1246,10 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr, const RuntimeShape& filter_shape, const int8_t* filter_data, const RuntimeShape& bias_shape, const float* bias_data, - const RuntimeShape& output_shape, float* output_data, - const RuntimeShape& im2col_shape, int8_t* im2col_data) { + const RuntimeShape& accum_scratch_shape, + int32_t* accum_scratch, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, + int8_t* im2col_data, CpuBackendContext* context) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const float output_activation_min = params.float_activation_min; @@ -1310,11 +1312,17 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr, std::fill_n(output_data, output_rows * output_cols, 0.0f); +#ifdef TFLITE_WITH_RUY_GEMV + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter_data, filter_rows, filter_cols, gemm_input_data, + scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch, + output_data, /*result_stride=*/1, context); +#else tensor_utils::MatrixBatchVectorMultiplyAccumulate( filter_data, filter_rows, filter_cols, gemm_input_data, scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data, /*result_stride=*/1); - +#endif AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max, bias_shape, bias_data, output_shape, output_data); diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 9ceaa2760da..0fc1a2d453d 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -27,6 +27,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. 
#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h" #include "tensorflow/lite/kernels/internal/optimized/sse_check.h" @@ -52,6 +53,15 @@ void MatrixBatchVectorMultiplyAccumulate( vectors, scaling_factors, n_batch, result, result_stride); } +void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context) { + SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vectors, scaling_factors, n_batch, result, result_stride); +} + void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index b3f7c0834ca..4ac48c8e7af 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -76,6 +76,16 @@ void MatrixBatchVectorMultiplyAccumulate( result_stride); } +void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vector, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context) { + PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, + scaling_factors, n_batch, result, + result_stride); +} + void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h index 62fe08ba7c0..c90b2588fdc 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/lite/kernels/internal/tensor_utils.h @@ -100,6 +100,15 @@ void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ vectors, const float* scaling_factors, int n_batch, float* __restrict__ result, int result_stride); +// Same as the function above, but provide a scratch buffer for the +// int8 x int8 -> int32 and a CpuBackendContext for the accumulator +// computation. +void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context); + // Same as the function above except that vector values // are quantized with asymmetric quantization per-batch and the matrix // is quantized per row. 
diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc index e6b76ee19a9..f1ea1afd681 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -1152,6 +1152,21 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001); } + // Call version of MatrixBatchVectorMultiplyAccumulate that uses + // CpuBackendGemm. + std::vector accum_scratch(a_rows * batches); + std::vector c_float_data_2(a_rows * batches, 0.0); + CpuBackendContext context; + MatrixBatchVectorMultiplyAccumulate( + a_int8_data, a_rows, a_cols, b_int8_data, scaling_factor_c, batches, + accum_scratch.data(), c_float_data_2.data(), + /*result_stride=*/1, &context); + + // Assert (again) we obtain the expected recovered float values. + for (int i = 0; i < a_rows * b_cols * batches; ++i) { + EXPECT_NEAR(expected_c_float_data[i], c_float_data_2[i], 0.001); + } + aligned_free(a_int8_data); } #endif // __ANDROID__ diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index b5824a6fb81..c97745f77e8 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -371,7 +371,7 @@ TfLiteStatus PopulateQuantizedLstmParams( void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMFullKernel; - context->AddTensors(context, /*tensors_to_add=*/7, + context->AddTensors(context, /*tensors_to_add=*/8, &op_data->scratch_tensor_index); return op_data; } @@ -871,7 +871,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(7); + node->temporaries = TfLiteIntArrayCreate(8); } else if (is_integer) { node->temporaries = TfLiteIntArrayCreate(6); } else { @@ -940,7 +940,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, cell_state_quantized, cell_state_quantized_size)); } - // Allocate temporary tensors to store scaling factors and product scaling // factors. 
The latter is a convenience storage which allows to quantize // a vector once (which produces the scaling factors) and multiply it with @@ -987,6 +986,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, recovered_cell_weights, recovered_cell_weights_size)); } + // Allocate a temporary tensor to store accumulate values for matrix + // multiplication before multiplication by scaling factor + node->temporaries->data[7] = op_data->scratch_tensor_index + 7; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/7); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {n_cell, n_batch}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); + accum_size->data[0] = n_cell; + accum_size->data[1] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, accum_scratch, accum_size)); + } } if (is_integer) { @@ -1135,6 +1149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/5); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, /*index=*/6); + TfLiteTensor* output_scratch_buffer = + GetTemporary(context, node, /*index=*/7); return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -1155,7 +1171,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, - cell_state_quantized, activation_state, cell_state, output); + cell_state_quantized, activation_state, cell_state, + output_scratch_buffer, output, + CpuBackendContext::GetFromContext(context)); } else { TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0); TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1); diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index b50a13a6bd6..869fd9abf49 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -37,6 +37,23 @@ inline float GetTensorScale(const TfLiteTensor* tensor) { return tensor == nullptr ? 1.0f : tensor->params.scale; } +inline void MatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, int32_t* scratch, float* __restrict__ result, + int result_stride, CpuBackendContext* context) { +// TODO(b/148289189) Remove when Ruy GEMV is the default. +#ifdef TFLITE_WITH_RUY_GEMV + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch, + result, result_stride, context); +#else + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + result_stride); +#endif +} + // Performs an LSTM batch inference step for input specified by input_ptr. 
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and // biases (*_bias_ptr), and buffers (*_scratch), along with additional @@ -461,7 +478,8 @@ inline void LstmStepHybrid( float* recovered_cell_weights, int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr) { + float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, + CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we // can check the existence of only one to the get the condition. @@ -506,37 +524,41 @@ inline void LstmStepHybrid( product_scaling_factors[b] = scaling_factors[b] * input_to_input_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, input_gate_scratch, - /*result_stride=*/1); + MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_ptr, n_cell, + n_input, quantized_input_ptr, + product_scaling_factors, n_batch, + accum_scratch_ptr, input_gate_scratch, + /*result_stride=*/1, context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_forget_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, forget_gate_scratch, - /*result_stride=*/1); + MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_ptr, n_cell, + n_input, quantized_input_ptr, + product_scaling_factors, n_batch, + accum_scratch_ptr, forget_gate_scratch, + /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_cell_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch, + /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_output_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, output_gate_scratch, - /*result_stride=*/1); + MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_ptr, n_cell, + n_input, quantized_input_ptr, + product_scaling_factors, n_batch, + accum_scratch_ptr, output_gate_scratch, + /*result_stride=*/1, context); } // For each batch and cell: compute aux_input_weight * aux_input. 
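The accum_scratch_ptr threaded through these gate computations ends up as the int32 destination of the CpuBackendGemm call inside NeonCpuBackendGemm (see neon_tensor_utils.cc above), which is why the scratch temporaries added in Prepare are sized {n_cell, n_batch}. The sketch below reconstructs that GEMM parameterization from the surrounding diff; the template arguments and enum names do not appear in the diff text, so treat them as assumptions rather than as the exact cpu_backend_gemm API.

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

// Sketch (assumed template arguments) of the int8 x int8 -> int32 GEMM setup
// used by NeonCpuBackendGemm for the hybrid path.
void SketchCpuBackendGemm(const int8_t* input_to_gate_weights,
                          const int8_t* input, const int32_t* bias,
                          int n_output, int n_input, int n_batch,
                          int32_t* scratch,
                          tflite::CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;
  using ::tflite::cpu_backend_gemm::Order;

  MatrixParams<int8_t> lhs_params;   // gate weights, row-major
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = n_output;        // one row per cell / output unit
  lhs_params.cols = n_input;

  MatrixParams<int8_t> rhs_params;   // quantized inputs, one column per batch
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = n_input;
  rhs_params.cols = n_batch;

  MatrixParams<int32_t> dst_params;  // raw int32 accumulators
  dst_params.order = Order::kColMajor;
  dst_params.rows = n_output;        // scratch must hold n_output * n_batch
  dst_params.cols = n_batch;

  GemmParams<int32_t, int32_t> gemm_params;
  if (bias) {
    gemm_params.bias = bias;         // optional; the LSTM gates pass null
  }
  ::tflite::cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights,
                                   rhs_params, input, dst_params, scratch,
                                   gemm_params, context);
}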
@@ -555,38 +577,38 @@ inline void LstmStepHybrid( product_scaling_factors[b] = scaling_factors[b] * aux_input_to_input_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( aux_input_to_input_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_forget_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( aux_input_to_forget_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_cell_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch, - /*result_stride=*/1); + quantized_aux_input_ptr, product_scaling_factors, n_batch, + accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_output_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( aux_input_to_output_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context); } if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { @@ -605,38 +627,38 @@ inline void LstmStepHybrid( product_scaling_factors[b] = scaling_factors[b] * recurrent_to_input_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_forget_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_cell_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( recurrent_to_cell_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); + accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_output_weights_scale; } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( recurrent_to_output_weights_ptr, n_cell, n_output, quantized_output_state_ptr, 
product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); + accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context); } // For each batch and cell: update input gate. @@ -768,11 +790,12 @@ inline void LstmStepHybrid( scaling_factors[b] * projection_weights_scale; } for (int b = 0; b < n_batch; b++) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( + MatrixBatchVectorMultiplyAccumulate( projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b], - /*n_batch=*/1, output_ptr + b * output_batch_leading_dim, - /*result_stride=*/1); + /*n_batch=*/1, accum_scratch_ptr, + output_ptr + b * output_batch_leading_dim, + /*result_stride=*/1, context); } } if (params->proj_clip > 0.0) { @@ -1335,7 +1358,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { + TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, + TfLiteTensor* output, CpuBackendContext* context) { TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); const int n_input = input->dims->data[input->dims->size - 1]; int max_time, n_batch; @@ -1389,60 +1413,59 @@ TfLiteStatus EvalHybrid( } float* output_ptr = GetTensorData(output) + t_rel * output_step + output_offset; - - LstmStepHybrid(input_ptr, GetTensorData(input_to_input_weights), - GetTensorScale(input_to_input_weights), - GetTensorData(input_to_forget_weights), - GetTensorScale(input_to_forget_weights), - GetTensorData(input_to_cell_weights), - GetTensorScale(input_to_cell_weights), - GetTensorData(input_to_output_weights), - GetTensorScale(input_to_output_weights), aux_input_ptr, - GetTensorData(aux_input_to_input_weights), - GetTensorScale(aux_input_to_input_weights), - GetTensorData(aux_input_to_forget_weights), - GetTensorScale(aux_input_to_forget_weights), - GetTensorData(aux_input_to_cell_weights), - GetTensorScale(aux_input_to_cell_weights), - GetTensorData(aux_input_to_output_weights), - GetTensorScale(aux_input_to_output_weights), - GetTensorData(recurrent_to_input_weights), - GetTensorScale(recurrent_to_input_weights), - GetTensorData(recurrent_to_forget_weights), - GetTensorScale(recurrent_to_forget_weights), - GetTensorData(recurrent_to_cell_weights), - GetTensorScale(recurrent_to_cell_weights), - GetTensorData(recurrent_to_output_weights), - GetTensorScale(recurrent_to_output_weights), - GetTensorData(cell_to_input_weights), - GetTensorScale(cell_to_input_weights), - GetTensorData(cell_to_forget_weights), - GetTensorScale(cell_to_forget_weights), - GetTensorData(cell_to_output_weights), - GetTensorScale(cell_to_output_weights), - GetTensorData(input_layer_norm_coefficients), - GetTensorData(forget_layer_norm_coefficients), - GetTensorData(cell_layer_norm_coefficients), - GetTensorData(output_layer_norm_coefficients), - GetTensorData(input_gate_bias), - GetTensorData(forget_gate_bias), - GetTensorData(cell_bias), - GetTensorData(output_gate_bias), - GetTensorData(projection_weights), - GetTensorScale(projection_weights), - GetTensorData(projection_bias), params, n_batch, - n_cell, n_input, aux_input_size, n_output, - output_batch_leading_dim, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, - GetTensorData(scaling_factors), - GetTensorData(prod_scaling_factors), - GetTensorData(recovered_cell_weights), - 
GetTensorData(input_quantized), - GetTensorData(aux_input_quantized), - GetTensorData(output_state_quantized), - GetTensorData(cell_state_quantized), - GetTensorData(output_state), - GetTensorData(cell_state), output_ptr); + LstmStepHybrid( + input_ptr, GetTensorData(input_to_input_weights), + GetTensorScale(input_to_input_weights), + GetTensorData(input_to_forget_weights), + GetTensorScale(input_to_forget_weights), + GetTensorData(input_to_cell_weights), + GetTensorScale(input_to_cell_weights), + GetTensorData(input_to_output_weights), + GetTensorScale(input_to_output_weights), aux_input_ptr, + GetTensorData(aux_input_to_input_weights), + GetTensorScale(aux_input_to_input_weights), + GetTensorData(aux_input_to_forget_weights), + GetTensorScale(aux_input_to_forget_weights), + GetTensorData(aux_input_to_cell_weights), + GetTensorScale(aux_input_to_cell_weights), + GetTensorData(aux_input_to_output_weights), + GetTensorScale(aux_input_to_output_weights), + GetTensorData(recurrent_to_input_weights), + GetTensorScale(recurrent_to_input_weights), + GetTensorData(recurrent_to_forget_weights), + GetTensorScale(recurrent_to_forget_weights), + GetTensorData(recurrent_to_cell_weights), + GetTensorScale(recurrent_to_cell_weights), + GetTensorData(recurrent_to_output_weights), + GetTensorScale(recurrent_to_output_weights), + GetTensorData(cell_to_input_weights), + GetTensorScale(cell_to_input_weights), + GetTensorData(cell_to_forget_weights), + GetTensorScale(cell_to_forget_weights), + GetTensorData(cell_to_output_weights), + GetTensorScale(cell_to_output_weights), + GetTensorData(input_layer_norm_coefficients), + GetTensorData(forget_layer_norm_coefficients), + GetTensorData(cell_layer_norm_coefficients), + GetTensorData(output_layer_norm_coefficients), + GetTensorData(input_gate_bias), + GetTensorData(forget_gate_bias), + GetTensorData(cell_bias), + GetTensorData(output_gate_bias), + GetTensorData(projection_weights), + GetTensorScale(projection_weights), + GetTensorData(projection_bias), params, n_batch, n_cell, + n_input, aux_input_size, n_output, output_batch_leading_dim, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, GetTensorData(scaling_factors), + GetTensorData(prod_scaling_factors), + GetTensorData(recovered_cell_weights), + GetTensorData(input_quantized), + GetTensorData(aux_input_quantized), + GetTensorData(output_state_quantized), + GetTensorData(cell_state_quantized), + GetTensorData(output_state), GetTensorData(cell_state), + GetTensorData(output_scratch_buffer), output_ptr, context); } } else { for (int b = 0; b < n_batch; b++) { @@ -1474,59 +1497,60 @@ TfLiteStatus EvalHybrid( float* cell_scratch_ptr = cell_scratch + b * n_cell; float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell; - LstmStepHybrid(input_ptr, GetTensorData(input_to_input_weights), - GetTensorScale(input_to_input_weights), - GetTensorData(input_to_forget_weights), - GetTensorScale(input_to_forget_weights), - GetTensorData(input_to_cell_weights), - GetTensorScale(input_to_cell_weights), - GetTensorData(input_to_output_weights), - GetTensorScale(input_to_output_weights), aux_input_ptr, - GetTensorData(aux_input_to_input_weights), - GetTensorScale(aux_input_to_input_weights), - GetTensorData(aux_input_to_forget_weights), - GetTensorScale(aux_input_to_forget_weights), - GetTensorData(aux_input_to_cell_weights), - GetTensorScale(aux_input_to_cell_weights), - GetTensorData(aux_input_to_output_weights), - GetTensorScale(aux_input_to_output_weights), - 
GetTensorData(recurrent_to_input_weights), - GetTensorScale(recurrent_to_input_weights), - GetTensorData(recurrent_to_forget_weights), - GetTensorScale(recurrent_to_forget_weights), - GetTensorData(recurrent_to_cell_weights), - GetTensorScale(recurrent_to_cell_weights), - GetTensorData(recurrent_to_output_weights), - GetTensorScale(recurrent_to_output_weights), - GetTensorData(cell_to_input_weights), - GetTensorScale(cell_to_input_weights), - GetTensorData(cell_to_forget_weights), - GetTensorScale(cell_to_forget_weights), - GetTensorData(cell_to_output_weights), - GetTensorScale(cell_to_output_weights), - GetTensorData(input_layer_norm_coefficients), - GetTensorData(forget_layer_norm_coefficients), - GetTensorData(cell_layer_norm_coefficients), - GetTensorData(output_layer_norm_coefficients), - GetTensorData(input_gate_bias), - GetTensorData(forget_gate_bias), - GetTensorData(cell_bias), - GetTensorData(output_gate_bias), - GetTensorData(projection_weights), - GetTensorScale(projection_weights), - GetTensorData(projection_bias), params, - /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, - output_batch_leading_dim, input_gate_scratch_ptr, - forget_gate_scratch_ptr, cell_scratch_ptr, - output_gate_scratch_ptr, - GetTensorData(scaling_factors), - GetTensorData(prod_scaling_factors), - GetTensorData(recovered_cell_weights), - GetTensorData(input_quantized), - GetTensorData(aux_input_quantized), - GetTensorData(output_state_quantized), - GetTensorData(cell_state_quantized), - output_state_ptr, cell_state_ptr, output_ptr); + LstmStepHybrid( + input_ptr, GetTensorData(input_to_input_weights), + GetTensorScale(input_to_input_weights), + GetTensorData(input_to_forget_weights), + GetTensorScale(input_to_forget_weights), + GetTensorData(input_to_cell_weights), + GetTensorScale(input_to_cell_weights), + GetTensorData(input_to_output_weights), + GetTensorScale(input_to_output_weights), aux_input_ptr, + GetTensorData(aux_input_to_input_weights), + GetTensorScale(aux_input_to_input_weights), + GetTensorData(aux_input_to_forget_weights), + GetTensorScale(aux_input_to_forget_weights), + GetTensorData(aux_input_to_cell_weights), + GetTensorScale(aux_input_to_cell_weights), + GetTensorData(aux_input_to_output_weights), + GetTensorScale(aux_input_to_output_weights), + GetTensorData(recurrent_to_input_weights), + GetTensorScale(recurrent_to_input_weights), + GetTensorData(recurrent_to_forget_weights), + GetTensorScale(recurrent_to_forget_weights), + GetTensorData(recurrent_to_cell_weights), + GetTensorScale(recurrent_to_cell_weights), + GetTensorData(recurrent_to_output_weights), + GetTensorScale(recurrent_to_output_weights), + GetTensorData(cell_to_input_weights), + GetTensorScale(cell_to_input_weights), + GetTensorData(cell_to_forget_weights), + GetTensorScale(cell_to_forget_weights), + GetTensorData(cell_to_output_weights), + GetTensorScale(cell_to_output_weights), + GetTensorData(input_layer_norm_coefficients), + GetTensorData(forget_layer_norm_coefficients), + GetTensorData(cell_layer_norm_coefficients), + GetTensorData(output_layer_norm_coefficients), + GetTensorData(input_gate_bias), + GetTensorData(forget_gate_bias), + GetTensorData(cell_bias), + GetTensorData(output_gate_bias), + GetTensorData(projection_weights), + GetTensorScale(projection_weights), + GetTensorData(projection_bias), params, + /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, + output_batch_leading_dim, input_gate_scratch_ptr, + forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr, + 
GetTensorData(scaling_factors), + GetTensorData(prod_scaling_factors), + GetTensorData(recovered_cell_weights), + GetTensorData(input_quantized), + GetTensorData(aux_input_quantized), + GetTensorData(output_state_quantized), + GetTensorData(cell_state_quantized), output_state_ptr, + cell_state_ptr, GetTensorData(output_scratch_buffer), + output_ptr, context); } } } diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index 34e000ac3d4..f0f9d2d38ec 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -154,7 +154,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, - TfLiteTensor* cell_state, TfLiteTensor* output); + TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, + TfLiteTensor* output, CpuBackendContext* context); TfLiteStatus EvalInteger( const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 470f8aec42b..b49974da2e0 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -90,7 +91,8 @@ enum TemporaryTensor { kScalingFactors = 4, kProductScalingFactors = 5, kRecoveredCellWeights = 6, - kNumTemporaryTensors = 7 + kAccumScratch = 7, + kNumTemporaryTensors }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -497,6 +499,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, recovered_cell_weights, recovered_cell_weights_size)); } + + // Allocate a temporary tensor to store the accumulated int32 values. 
+ node->temporaries->data[kAccumScratch] = + scratch_tensor_index + kAccumScratch; + TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {n_cell, n_batch}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); + accum_size->data[0] = n_cell; + accum_size->data[1] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, accum_scratch, accum_size)); + } } return kTfLiteOk; } @@ -615,6 +633,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/5); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, /*index=*/6); + TfLiteTensor* accum_scratch = + GetTemporary(context, node, /*index=*/kAccumScratch); return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -633,7 +653,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*output_offset=*/0, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, - cell_state_quantized, activation_state, cell_state, output); + cell_state_quantized, activation_state, cell_state, accum_scratch, + output, CpuBackendContext::GetFromContext(context)); } default: context->ReportError(context, "Type %d is not currently supported.",