Dispatch to CpuBackendGemm for Hybrid Ops.

PiperOrigin-RevId: 291814226
Change-Id: Ie7671d134082cf88e4c75f3c7da69bc2c4ae9fe0
T.J. Alumbaugh 2020-01-27 15:31:55 -08:00 committed by TensorFlower Gardener
parent 6ff06bafbc
commit be369f57e9
17 changed files with 435 additions and 168 deletions
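For orientation, a minimal scalar sketch (not code from this commit; names are illustrative) of what the new scratch-taking MatrixBatchVectorMultiplyAccumulate overload computes for hybrid ops: the int8 x int8 matrix/batch-vector products are accumulated into an int32 scratch buffer (the part now handed to CpuBackendGemm), and each batch's accumulators are then scaled back to float with that batch's scaling factor.

#include <cstdint>

// matrix is [m_rows x m_cols] int8, vectors holds n_batch rows of m_cols int8
// values, scratch holds n_batch * m_rows int32 accumulators (batch-major,
// matching the layout the NEON/CpuBackendGemm path produces), and result is
// accumulated in place. Callers such as LstmStepHybrid pass per-batch factors
// that already fold in the weight scale.
void ReferenceHybridMatmulAccumulate(const int8_t* matrix, int m_rows,
                                     int m_cols, const int8_t* vectors,
                                     const float* scaling_factors, int n_batch,
                                     int32_t* scratch, float* result) {
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < m_rows; ++r) {
      int32_t acc = 0;
      for (int c = 0; c < m_cols; ++c) {
        acc += static_cast<int32_t>(matrix[r * m_cols + c]) *
               static_cast<int32_t>(vectors[b * m_cols + c]);
      }
      scratch[b * m_rows + r] = acc;                       // int8 x int8 -> int32
      result[b * m_rows + r] += scaling_factors[b] * acc;  // rescale to float
    }
  }
}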

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -137,8 +138,9 @@ enum TemporaryTensor {
kScalingFactors = 7,
kProductScalingFactors = 8,
kRecoveredCellWeights = 9,
kAuxInputQuantized = 10, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 11
kAccumScratchBuffer = 10,
kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -726,6 +728,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] =
*scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int n_cell = std::max(n_fw_cell, n_bw_cell);
if (has_aux_input) {
n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
}
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
// Only allocate a temporary tensor for quantized auxiliary input if we are
// actually going to use it.
if (has_aux_input) {
@ -977,6 +1001,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
@ -998,7 +1024,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
fw_activation_state_quantized, fw_cell_state_quantized,
fw_activation_state, fw_cell_state, fw_output);
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, fw_pass_status);
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
@ -1021,7 +1048,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
bw_activation_state_quantized, bw_cell_state_quantized,
bw_activation_state, bw_cell_state, actual_bw_output);
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}

View File

@ -71,6 +71,7 @@ struct OpData {
int input_quantized_id = kTensorNotAllocated;
int scaling_factors_id = kTensorNotAllocated;
int input_offset_id = kTensorNotAllocated;
int accum_scratch_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
@ -92,6 +93,7 @@ struct OpData {
int32_t hwcn_weights_index;
int32_t input_quantized_index;
int32_t scaling_factors_index;
int32_t accum_scratch_index;
int32_t input_offset_index;
bool need_hwcn_weights = false;
@ -262,6 +264,13 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
}
++temporaries_count;
// Allocate a tensor to store the accumulators for the matrix multiply.
data->accum_scratch_index = temporaries_count;
if (data->accum_scratch_id == kTensorNotAllocated) {
TF_LITE_ENSURE_OK(
context, context->AddTensors(context, 1, &data->accum_scratch_id));
}
++temporaries_count;
if (is_per_channel) {
data->input_offset_index = temporaries_count;
if (data->input_offset_id == kTensorNotAllocated) {
@ -485,6 +494,21 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
scaling_factors_size));
}
node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, data->accum_scratch_index);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {channels_out, batches};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = channels_out;
accum_scratch_size->data[1] = batches;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
if (is_hybrid_per_channel) {
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(
@ -771,7 +795,7 @@ template <KernelType kernel_type>
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* output) {
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
@ -816,9 +840,11 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
op_params, scaling_factors_ptr, GetTensorShape(input),
quantized_input_ptr_batch, GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(bias), GetTensorShape(accum_scratch),
GetTensorData<int32_t>(accum_scratch), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(im2col),
GetTensorData<int8_t>(im2col));
GetTensorData<int8_t>(im2col),
CpuBackendContext::GetFromContext(context));
break;
}
}
@ -859,8 +885,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
filter, bias, im2col, output);
} else {
TfLiteTensor* accum_scratch =
&context->tensors[node->temporaries
->data[data->accum_scratch_index]];
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
bias, im2col, output);
bias, im2col, accum_scratch, output);
}
} else {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,

View File

@ -116,7 +116,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/2,
context->AddTensors(context, /*tensors_to_add=*/3,
&op_data->scratch_tensor_index);
return op_data;
}
@ -183,10 +183,12 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
// If we have to perform on-the-fly quantization (with quantized weights and
// float inputs) first we need to quantize the inputs. Allocate a temporary
// buffer to store the intermediate quantized values.
// Additionally, we allocate a temporary buffer to store the accumulated
// quantized values prior to multiplication by the scaling factor.
if (input->type == kTfLiteFloat32 &&
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(2);
node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
@ -201,6 +203,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
@ -208,6 +211,20 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
}
node->temporaries->data[2] = data->scratch_tensor_index + 2;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = num_units;
accum_size->data[1] = batch_size;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
// Resize output.
@ -341,11 +358,19 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
}
// Compute output += weight * quantized_input
#ifdef TFLITE_WITH_RUY_GEMV
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, scratch, GetTensorData<float>(output),
/*result_stride=*/1, CpuBackendContext::GetFromContext(context));
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, GetTensorData<float>(output),
/*result_stride=*/1);
#endif
// Apply activation function to floats.
tensor_utils::ApplyActivationToVector(
GetTensorData<float>(output), batch_size * num_units, params->activation,
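The scaling_factors consumed above come from the on-the-fly symmetric quantization described at the top of PrepareImpl: float inputs are quantized to int8 per batch row, and the resulting per-row scale is what turns the int32 accumulators back into floats. A minimal sketch of that quantization step, under the assumption that it mirrors TFLite's tensor_utils helpers (names here are illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One scaling factor per batch row: scale = max|x| / 127, q = round(x / scale),
// clamped to [-127, 127]. Dequantization later multiplies the int32 accumulator
// by this scale (and by the filter scale).
void QuantizeRowSymmetric(const float* values, int size, int8_t* quantized,
                          float* scaling_factor) {
  float max_abs = 0.0f;
  for (int i = 0; i < size; ++i) {
    max_abs = std::max(max_abs, std::fabs(values[i]));
  }
  if (max_abs == 0.0f) {
    *scaling_factor = 1.0f;
    std::fill(quantized, quantized + size, 0);
    return;
  }
  *scaling_factor = max_abs / 127.0f;
  const float inv_scale = 127.0f / max_abs;
  for (int i = 0; i < size; ++i) {
    const float q = std::round(values[i] * inv_scale);
    quantized[i] = static_cast<int8_t>(std::min(127.0f, std::max(-127.0f, q)));
  }
}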

View File

@ -639,6 +639,7 @@ cc_library(
":neon_tensor_utils",
":portable_tensor_utils",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels:op_macros",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <sys/types.h>
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
@ -2551,8 +2552,10 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, float* scaling_factors_ptr,
float output_activation_min, float output_activation_max,
int32_t* scratch_data, const Dims<4>& scratch_dims,
float* output_data, const Dims<4>& output_dims,
int8_t* im2col_data, const Dims<4>& im2col_dims) {
int8_t* im2col_data, const Dims<4>& im2col_dims,
CpuBackendContext* context) {
tflite::ConvParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;
@ -2565,8 +2568,9 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
input_data, DimsToShape(filter_dims), filter_data,
DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
output_data, DimsToShape(im2col_dims), im2col_data);
DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
scratch_data, DimsToShape(output_dims), output_data,
DimsToShape(im2col_dims), im2col_data, context);
}
template <FusedActivationFunctionType Ac>

View File

@ -974,7 +974,9 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
dst_params.cols = n_batch;
GemmParams<int32, int32> gemm_params;
gemm_params.bias = bias;
if (bias) {
gemm_params.bias = bias;
}
cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
dst_params, scratch, gemm_params, context);
}
@ -1145,6 +1147,48 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
free(aligned_vec_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
if (m_rows % 4 == 0 && result_stride == 1) {
const int32_t* bias = static_cast<const int32_t*>(nullptr);
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
/*output_zp =*/0, scratch, context);
// Multiply by float scaling factors and write to result
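// scratch holds n_batch blocks of m_rows int32 accumulators (batch-major), so
// scaling_factors[i / m_rows] below selects the factor of the batch that
// element i belongs to.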
const int total_size = n_batch * m_rows;
int i = 0;
for (; i <= total_size - 8; i += 8, result += 8 * result_stride) {
const float batch_scaling_factor0 = scaling_factors[i / m_rows];
const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
const float32x4_t result0 =
vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
const float32x4_t result1 = vmlaq_f32(
vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
vst1q_f32(result, result0);
vst1q_f32(result + 4 * result_stride, result1);
}
scratch += i;
for (; i < total_size; i++, result += result_stride) {
const float batch_scaling_factor = scaling_factors[i / m_rows];
int32_t x = *(scratch++);
*result += x * batch_scaling_factor;
}
return;
}
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result,
result_stride);
}
void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output) {

View File

@ -18,6 +18,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
@ -41,6 +42,16 @@ void MatrixBatchVectorMultiplyAccumulate(
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vectors, scaling_factors, n_batch, scratch, result,
result_stride, context);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

View File

@ -18,6 +18,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#if defined(_MSC_VER)
@ -42,6 +43,14 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride);
// Same as above, but with a scratch buffer and a CpuBackendContext for the
// int8 x int8 -> int32 accumulation.
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context);
// Matrix multiplication for quantized values using asymmetric quantization.
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,

View File

@ -1246,8 +1246,10 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const RuntimeShape& filter_shape,
const int8_t* filter_data,
const RuntimeShape& bias_shape, const float* bias_data,
const RuntimeShape& output_shape, float* output_data,
const RuntimeShape& im2col_shape, int8_t* im2col_data) {
const RuntimeShape& accum_scratch_shape,
int32_t* accum_scratch, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape,
int8_t* im2col_data, CpuBackendContext* context) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const float output_activation_min = params.float_activation_min;
@ -1310,11 +1312,17 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
std::fill_n(output_data, output_rows * output_cols, 0.0f);
#ifdef TFLITE_WITH_RUY_GEMV
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, filter_rows, filter_cols, gemm_input_data,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
output_data, /*result_stride=*/1, context);
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, filter_rows, filter_cols, gemm_input_data,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
/*result_stride=*/1);
#endif
AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
bias_shape, bias_data, output_shape,
output_data);

View File

@ -27,6 +27,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_check.h"
@ -52,6 +53,15 @@ void MatrixBatchVectorMultiplyAccumulate(
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
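// Note: this overload does not use the scratch buffer yet; it forwards to the
// existing SSE/portable kernel, so scratch and context go unused here.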
SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

View File

@ -76,6 +76,16 @@ void MatrixBatchVectorMultiplyAccumulate(
result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vector, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
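// The portable implementation has no scratch-based kernel, so scratch and
// context are ignored and the plain accumulate routine is used.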
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
scaling_factors, n_batch, result,
result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

View File

@ -100,6 +100,15 @@ void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride);
// Same as the function above, but takes a scratch buffer for the
// int8 x int8 -> int32 accumulation and a CpuBackendContext for the
// computation.
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context);
// Same as the function above except that vector values
// are quantized with asymmetric quantization per-batch and the matrix
// is quantized per row.

View File

@ -1152,6 +1152,21 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001);
}
// Call the version of MatrixBatchVectorMultiplyAccumulate that uses
// CpuBackendGemm.
std::vector<int32_t> accum_scratch(a_rows * batches);
std::vector<float> c_float_data_2(a_rows * batches, 0.0);
CpuBackendContext context;
MatrixBatchVectorMultiplyAccumulate(
a_int8_data, a_rows, a_cols, b_int8_data, scaling_factor_c, batches,
accum_scratch.data(), c_float_data_2.data(),
/*result_stride=*/1, &context);
// Assert (again) that we obtain the expected recovered float values.
for (int i = 0; i < a_rows * b_cols * batches; ++i) {
EXPECT_NEAR(expected_c_float_data[i], c_float_data_2[i], 0.001);
}
aligned_free(a_int8_data);
}
#endif // __ANDROID__

View File

@ -371,7 +371,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
op_data->kernel_type = kTfLiteLSTMFullKernel;
context->AddTensors(context, /*tensors_to_add=*/7,
context->AddTensors(context, /*tensors_to_add=*/8,
&op_data->scratch_tensor_index);
return op_data;
}
@ -871,7 +871,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(7);
node->temporaries = TfLiteIntArrayCreate(8);
} else if (is_integer) {
node->temporaries = TfLiteIntArrayCreate(6);
} else {
@ -940,7 +940,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, cell_state_quantized,
cell_state_quantized_size));
}
// Allocate temporary tensors to store scaling factors and product scaling
// factors. The latter is a convenience storage which allows to quantize
// a vector once (which produces the scaling factors) and multiply it with
@ -987,6 +986,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, recovered_cell_weights,
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store the accumulated int32 values from the
// matrix multiplications before they are multiplied by the scaling factors.
node->temporaries->data[7] = op_data->scratch_tensor_index + 7;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/7);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
if (is_integer) {
@ -1135,6 +1149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
TfLiteTensor* output_scratch_buffer =
GetTemporary(context, node, /*index=*/7);
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -1155,7 +1171,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
scaling_factors, prod_scaling_factors, recovered_cell_weights,
input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, output);
cell_state_quantized, activation_state, cell_state,
output_scratch_buffer, output,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);

View File

@ -37,6 +37,23 @@ inline float GetTensorScale(const TfLiteTensor* tensor) {
return tensor == nullptr ? 1.0f : tensor->params.scale;
}
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
// TODO(b/148289189) Remove when Ruy GEMV is the default.
#ifdef TFLITE_WITH_RUY_GEMV
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
result, result_stride, context);
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
result_stride);
#endif
}
// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
@ -461,7 +478,8 @@ inline void LstmStepHybrid(
float* recovered_cell_weights, int8_t* quantized_input_ptr,
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr) {
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to get the condition.
@ -506,37 +524,41 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * input_to_input_weights_scale;
}
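// product_scaling_factors[b] now carries both the per-batch input scale and
// this gate's weight scale, so the int32 accumulators staged in
// accum_scratch_ptr are rescaled to float in a single multiply inside the call
// below. The same pattern repeats for every gate and for the aux/recurrent
// inputs.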
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, input_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, input_gate_scratch,
/*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, forget_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, forget_gate_scratch,
/*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
/*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, output_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, output_gate_scratch,
/*result_stride=*/1, context);
}
// For each batch and cell: compute aux_input_weight * aux_input.
@ -555,38 +577,38 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
/*result_stride=*/1);
quantized_aux_input_ptr, product_scaling_factors, n_batch,
accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context);
}
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
@ -605,38 +627,38 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);
accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context);
}
// For each batch and cell: update input gate.
@ -768,11 +790,12 @@ inline void LstmStepHybrid(
scaling_factors[b] * projection_weights_scale;
}
for (int b = 0; b < n_batch; b++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
/*result_stride=*/1);
/*n_batch=*/1, accum_scratch_ptr,
output_ptr + b * output_batch_leading_dim,
/*result_stride=*/1, context);
}
}
if (params->proj_clip > 0.0) {
@ -1335,7 +1358,8 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, CpuBackendContext* context) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
@ -1389,60 +1413,59 @@ TfLiteStatus EvalHybrid(
}
float* output_ptr =
GetTensorData<float>(output) + t_rel * output_step + output_offset;
LstmStepHybrid(input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch,
n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state),
GetTensorData<float>(cell_state), output_ptr);
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
}
} else {
for (int b = 0; b < n_batch; b++) {
@ -1474,59 +1497,60 @@ TfLiteStatus EvalHybrid(
float* cell_scratch_ptr = cell_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepHybrid(input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr,
output_gate_scratch_ptr,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
output_state_ptr, cell_state_ptr, output_ptr);
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
output_ptr, context);
}
}
}

View File

@ -154,7 +154,8 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output);
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, CpuBackendContext* context);
TfLiteStatus EvalInteger(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -90,7 +91,8 @@ enum TemporaryTensor {
kScalingFactors = 4,
kProductScalingFactors = 5,
kRecoveredCellWeights = 6,
kNumTemporaryTensors = 7
kAccumScratch = 7,
kNumTemporaryTensors
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -497,6 +499,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, recovered_cell_weights,
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratch] =
scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
return kTfLiteOk;
}
@ -615,6 +633,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
TfLiteTensor* accum_scratch =
GetTemporary(context, node, /*index=*/kAccumScratch);
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
@ -633,7 +653,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, output);
cell_state_quantized, activation_state, cell_state, accum_scratch,
output, CpuBackendContext::GetFromContext(context));
}
default:
context->ReportError(context, "Type %d is not currently supported.",