Dispatch to CpuBackendGemm for Hybrid Ops.
PiperOrigin-RevId: 291814226 Change-Id: Ie7671d134082cf88e4c75f3c7da69bc2c4ae9fe0
parent 6ff06bafbc · commit be369f57e9
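This change adds an overload of tensor_utils::MatrixBatchVectorMultiplyAccumulate that takes an int32 scratch buffer and a CpuBackendContext, so hybrid kernels (float activations, int8 weights) can run their int8 x int8 -> int32 accumulation through CpuBackendGemm and then rescale the accumulators to float. The sketch below is a minimal illustration of the new call shape, modeled on the unit test added later in this diff; buffer sizes and values are illustrative only, and the declaration itself is the one added to tensor_utils.h below.

// Minimal usage sketch of the new overload (illustrative values only).
#include <cstdint>
#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

void ExampleHybridAccumulate() {
  const int m_rows = 4;   // output rows (e.g. n_cell); a multiple of 4 hits the GEMM path
  const int m_cols = 8;   // input depth
  const int n_batch = 2;  // batch size
  std::vector<int8_t> matrix(m_rows * m_cols, 1);     // quantized weights
  std::vector<int8_t> vectors(m_cols * n_batch, 1);   // quantized inputs, one per batch
  std::vector<float> scaling_factors(n_batch, 0.5f);  // per-batch scaling factors
  std::vector<int32_t> scratch(m_rows * n_batch);     // int32 accumulators (accum_scratch)
  std::vector<float> result(m_rows * n_batch, 0.0f);  // float output, accumulated into
  tflite::CpuBackendContext context;
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      matrix.data(), m_rows, m_cols, vectors.data(), scaling_factors.data(),
      n_batch, scratch.data(), result.data(), /*result_stride=*/1, &context);
}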
@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

@ -137,8 +138,9 @@ enum TemporaryTensor {
kScalingFactors = 7,
kProductScalingFactors = 8,
kRecoveredCellWeights = 9,
kAuxInputQuantized = 10, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 11
kAccumScratchBuffer = 10,
kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {

@ -726,6 +728,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratchBuffer] =
*scratch_tensor_index + kAccumScratchBuffer;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int n_cell = std::max(n_fw_cell, n_bw_cell);
if (has_aux_input) {
n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
}
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
// Only allocate a temporary tensor for quantized auxiliary input if we are
// actually going to use it.
if (has_aux_input) {

@ -977,6 +1001,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* aux_input_quantized =
use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,

@ -998,7 +1024,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
fw_activation_state_quantized, fw_cell_state_quantized,
fw_activation_state, fw_cell_state, fw_output);
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, fw_pass_status);
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(

@ -1021,7 +1048,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
bw_activation_state_quantized, bw_cell_state_quantized,
bw_activation_state, bw_cell_state, actual_bw_output);
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}

@ -71,6 +71,7 @@ struct OpData {
int input_quantized_id = kTensorNotAllocated;
int scaling_factors_id = kTensorNotAllocated;
int input_offset_id = kTensorNotAllocated;
int accum_scratch_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can

@ -92,6 +93,7 @@ struct OpData {
int32_t hwcn_weights_index;
int32_t input_quantized_index;
int32_t scaling_factors_index;
int32_t accum_scratch_index;
int32_t input_offset_index;
bool need_hwcn_weights = false;

@ -262,6 +264,13 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
}
++temporaries_count;
// Allocate tensor to store the accumulators for the matrix multiply.
data->accum_scratch_index = temporaries_count;
if (data->accum_scratch_id == kTensorNotAllocated) {
TF_LITE_ENSURE_OK(
context, context->AddTensors(context, 1, &data->accum_scratch_id));
}
++temporaries_count;
if (is_per_channel) {
data->input_offset_index = temporaries_count;
if (data->input_offset_id == kTensorNotAllocated) {

@ -485,6 +494,21 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
scaling_factors_size));
}
node->temporaries->data[data->accum_scratch_index] = data->accum_scratch_id;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, data->accum_scratch_index);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {channels_out, batches};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
accum_scratch_size->data[0] = channels_out;
accum_scratch_size->data[1] = batches;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
accum_scratch_size));
}
if (is_hybrid_per_channel) {
const auto* affine_quantization =
reinterpret_cast<TfLiteAffineQuantization*>(

@ -771,7 +795,7 @@ template <KernelType kernel_type>
void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* output) {
TfLiteTensor* accum_scratch, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);

@ -816,9 +840,11 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
op_params, scaling_factors_ptr, GetTensorShape(input),
quantized_input_ptr_batch, GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(bias), GetTensorShape(accum_scratch),
GetTensorData<int32_t>(accum_scratch), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(im2col),
GetTensorData<int8_t>(im2col));
GetTensorData<int8_t>(im2col),
CpuBackendContext::GetFromContext(context));
break;
}
}

@ -859,8 +885,11 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
filter, bias, im2col, output);
} else {
TfLiteTensor* accum_scratch =
&context->tensors[node->temporaries
->data[data->accum_scratch_index]];
EvalHybrid<kernel_type>(context, node, params, data, input, filter,
bias, im2col, output);
bias, im2col, accum_scratch, output);
}
} else {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,

@ -116,7 +116,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
auto* op_data = new OpData();
context->AddTensors(context, /*tensors_to_add=*/2,
context->AddTensors(context, /*tensors_to_add=*/3,
&op_data->scratch_tensor_index);
return op_data;
}

@ -183,10 +183,12 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
// If we have to perform on-the-fly quantization (with quantized weights and
// float inputs) first we need to quantize the inputs. Allocate a temporary
// buffer to store the intermediate quantized values.
// Additionally, we allocate a temporary buffer to store the accumulated
// quantized values prior to multiplication by the scaling factor.
if (input->type == kTfLiteFloat32 &&
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(2);
node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = data->scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
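The comment in the hunk above describes the hybrid fully-connected path: the float input is quantized on the fly, the matrix product is accumulated in int32 (the values the new accum_scratch temporary holds), and the accumulators are then multiplied by the scaling factor to recover floats. A self-contained conceptual sketch of that computation follows; it is not the committed code, and the names and the per-batch max-abs quantization rule are illustrative only.

// Conceptual sketch of a hybrid matmul (single batch for brevity):
// quantize the float input, accumulate in int32, rescale to float.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void HybridFullyConnectedSketch(const std::vector<int8_t>& weights,  // [rows x cols]
                                float weight_scale,
                                const std::vector<float>& input,     // [cols]
                                int rows, int cols,
                                std::vector<float>* output) {        // [rows]
  // On-the-fly symmetric quantization of the float input.
  float max_abs = 0.f;
  for (float v : input) max_abs = std::max(max_abs, std::fabs(v));
  if (max_abs == 0.f) return;
  const float input_scale = max_abs / 127.f;
  std::vector<int8_t> quantized(cols);
  for (int c = 0; c < cols; ++c) {
    quantized[c] = static_cast<int8_t>(std::round(input[c] / input_scale));
  }
  for (int r = 0; r < rows; ++r) {
    // int8 x int8 -> int32 accumulation: this is what accum_scratch stores.
    int32_t acc = 0;
    for (int c = 0; c < cols; ++c) acc += weights[r * cols + c] * quantized[c];
    // Multiply by the combined scaling factor to recover a float result.
    (*output)[r] += acc * input_scale * weight_scale;
  }
}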
@ -201,6 +203,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {batch_size};
if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);

@ -208,6 +211,20 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
}
node->temporaries->data[2] = data->scratch_tensor_index + 2;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {num_units, batch_size};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = num_units;
accum_size->data[1] = batch_size;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
// Resize output.

@ -341,11 +358,19 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
}
// Compute output += weight * quantized_input
#ifdef TFLITE_WITH_RUY_GEMV
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2);
int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, scratch, GetTensorData<float>(output),
/*result_stride=*/1, CpuBackendContext::GetFromContext(context));
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
batch_size, GetTensorData<float>(output),
/*result_stride=*/1);
#endif
// Apply activation function to floats.
tensor_utils::ApplyActivationToVector(
GetTensorData<float>(output), batch_size * num_units, params->activation,

@ -639,6 +639,7 @@ cc_library(
":neon_tensor_utils",
":portable_tensor_utils",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels:op_macros",
],
)

@ -19,6 +19,7 @@ limitations under the License.
#include <sys/types.h>
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"

@ -2551,8 +2552,10 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, float* scaling_factors_ptr,
float output_activation_min, float output_activation_max,
int32_t* scratch_data, const Dims<4>& scratch_dims,
float* output_data, const Dims<4>& output_dims,
int8_t* im2col_data, const Dims<4>& im2col_dims) {
int8_t* im2col_data, const Dims<4>& im2col_dims,
CpuBackendContext* context) {
tflite::ConvParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;

@ -2565,8 +2568,9 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
input_data, DimsToShape(filter_dims), filter_data,
DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
output_data, DimsToShape(im2col_dims), im2col_data);
DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
scratch_data, DimsToShape(output_dims), output_data,
DimsToShape(im2col_dims), im2col_data, context);
}
template <FusedActivationFunctionType Ac>

@ -974,7 +974,9 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
dst_params.cols = n_batch;
GemmParams<int32, int32> gemm_params;
gemm_params.bias = bias;
if (bias) {
gemm_params.bias = bias;
}
cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
dst_params, scratch, gemm_params, context);
}

@ -1145,6 +1147,48 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
free(aligned_vec_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
if (m_rows % 4 == 0 && result_stride == 1) {
const int32_t* bias = static_cast<const int32_t*>(nullptr);
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
/*output_zp =*/0, scratch, context);
// Multiply by float scaling factors and write to result
const int total_size = n_batch * m_rows;
int i = 0;
for (; i <= total_size - 8; i += 8, result += 8 * result_stride) {
const float batch_scaling_factor0 = scaling_factors[i / m_rows];
const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
const int32x4_t scratch_val0 = vld1q_s32(scratch + i);
const int32x4_t scratch_val1 = vld1q_s32(scratch + i + 4);
const float32x4_t float_val0 = vcvtq_f32_s32(scratch_val0);
const float32x4_t float_val1 = vcvtq_f32_s32(scratch_val1);
const float32x4_t result0 =
vmlaq_f32(vld1q_f32(result), float_val0, scaling_factor0);
const float32x4_t result1 = vmlaq_f32(
vld1q_f32(result + 4 * result_stride), float_val1, scaling_factor1);
vst1q_f32(result, result0);
vst1q_f32(result + 4 * result_stride, result1);
}
scratch += i;
for (; i < total_size; i++, result += result_stride) {
const float batch_scaling_factor = scaling_factors[i / m_rows];
int32_t x = *(scratch++);
*result += x * batch_scaling_factor;
}
return;
}
NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
scaling_factors, n_batch, result,
result_stride);
}
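The NEON intrinsics above convert the int32 accumulators produced by NeonCpuBackendGemm back to float eight at a time, scaling each batch's accumulators by that batch's scaling factor; the scalar tail loop at the end handles the remainder. As a reading aid only (not part of the commit), the whole post-GEMM step is equivalent to the following scalar form, valid for result_stride == 1 as the GEMM path requires:

// Reading aid (not committed code): scalar equivalent of the NEON
// scale-and-accumulate step above, for result_stride == 1.
#include <cstdint>

inline void ScaleAccumulatorsToFloat(const int32_t* scratch,
                                     const float* scaling_factors,
                                     int n_batch, int m_rows, float* result) {
  for (int i = 0; i < n_batch * m_rows; ++i) {
    // All accumulators of batch (i / m_rows) share that batch's scaling factor.
    result[i] += scratch[i] * scaling_factors[i / m_rows];
  }
}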
void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output) {

@ -18,6 +18,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"

@ -41,6 +42,16 @@ void MatrixBatchVectorMultiplyAccumulate(
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vectors, scaling_factors, n_batch, scratch, result,
result_stride, context);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

@ -18,6 +18,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#if defined(_MSC_VER)

@ -42,6 +43,14 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride);
// Same as above but with a scratch buffer and CpuBackendContext for the
// int8 x int8 -> int32 accumulation computation
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context);
// Matrix multiplication for quantized values using asymmetric quantization.
void NeonMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,

@ -1246,8 +1246,10 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const RuntimeShape& filter_shape,
const int8_t* filter_data,
const RuntimeShape& bias_shape, const float* bias_data,
const RuntimeShape& output_shape, float* output_data,
const RuntimeShape& im2col_shape, int8_t* im2col_data) {
const RuntimeShape& accum_scratch_shape,
int32_t* accum_scratch, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape,
int8_t* im2col_data, CpuBackendContext* context) {
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
const float output_activation_min = params.float_activation_min;

@ -1310,11 +1312,17 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
std::fill_n(output_data, output_rows * output_cols, 0.0f);
#ifdef TFLITE_WITH_RUY_GEMV
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, filter_rows, filter_cols, gemm_input_data,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
output_data, /*result_stride=*/1, context);
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
filter_data, filter_rows, filter_cols, gemm_input_data,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
/*result_stride=*/1);
#endif
AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
bias_shape, bias_data, output_shape,
output_data);

@ -27,6 +27,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_check.h"

@ -52,6 +53,15 @@ void MatrixBatchVectorMultiplyAccumulate(
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vectors, scaling_factors, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

@ -76,6 +76,16 @@ void MatrixBatchVectorMultiplyAccumulate(
result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vector, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector,
scaling_factors, n_batch, result,
result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,

@ -100,6 +100,15 @@ void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride);
// Same as the function above, but provide a scratch buffer for the
// int8 x int8 -> int32 and a CpuBackendContext for the accumulator
// computation.
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context);
// Same as the function above except that vector values
// are quantized with asymmetric quantization per-batch and the matrix
// is quantized per row.

@ -1152,6 +1152,21 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) {
EXPECT_NEAR(expected_c_float_data[i], c_float_data[i], 0.001);
}
// Call version of MatrixBatchVectorMultiplyAccumulate that uses
// CpuBackendGemm.
std::vector<int32_t> accum_scratch(a_rows * batches);
std::vector<float> c_float_data_2(a_rows * batches, 0.0);
CpuBackendContext context;
MatrixBatchVectorMultiplyAccumulate(
a_int8_data, a_rows, a_cols, b_int8_data, scaling_factor_c, batches,
accum_scratch.data(), c_float_data_2.data(),
/*result_stride=*/1, &context);
// Assert (again) we obtain the expected recovered float values.
for (int i = 0; i < a_rows * b_cols * batches; ++i) {
EXPECT_NEAR(expected_c_float_data[i], c_float_data_2[i], 0.001);
}
aligned_free(a_int8_data);
}
#endif // __ANDROID__

@ -371,7 +371,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
op_data->kernel_type = kTfLiteLSTMFullKernel;
context->AddTensors(context, /*tensors_to_add=*/7,
context->AddTensors(context, /*tensors_to_add=*/8,
&op_data->scratch_tensor_index);
return op_data;
}

@ -871,7 +871,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {
node->temporaries = TfLiteIntArrayCreate(7);
node->temporaries = TfLiteIntArrayCreate(8);
} else if (is_integer) {
node->temporaries = TfLiteIntArrayCreate(6);
} else {

@ -940,7 +940,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, cell_state_quantized,
cell_state_quantized_size));
}
// Allocate temporary tensors to store scaling factors and product scaling
// factors. The latter is a convenience storage which allows to quantize
// a vector once (which produces the scaling factors) and multiply it with

@ -987,6 +986,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, recovered_cell_weights,
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store accumulate values for matrix
// multiplication before multiplication by scaling factor
node->temporaries->data[7] = op_data->scratch_tensor_index + 7;
TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/7);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
if (is_integer) {

@ -1135,6 +1149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
TfLiteTensor* output_scratch_buffer =
GetTemporary(context, node, /*index=*/7);
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,

@ -1155,7 +1171,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
scaling_factors, prod_scaling_factors, recovered_cell_weights,
input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, output);
cell_state_quantized, activation_state, cell_state,
output_scratch_buffer, output,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);

@ -37,6 +37,23 @@ inline float GetTensorScale(const TfLiteTensor* tensor) {
return tensor == nullptr ? 1.0f : tensor->params.scale;
}
inline void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, int32_t* scratch, float* __restrict__ result,
int result_stride, CpuBackendContext* context) {
// TODO(b/148289189) Remove when Ruy GEMV is the default.
#ifdef TFLITE_WITH_RUY_GEMV
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch,
result, result_stride, context);
#else
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
result_stride);
#endif
}
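This wrapper keeps the TFLITE_WITH_RUY_GEMV decision in one place: the gate computations in LstmStepHybrid below always pass accum_scratch_ptr and the CpuBackendContext, and the wrapper either forwards them to the GEMM-backed overload or drops them and falls back to the legacy path. A hypothetical call site (a sketch only; the real ones follow in this diff) looks like:

// Hypothetical gate computation (sketch; parameter names mirror LstmStepHybrid).
// The caller always threads the scratch buffer and context through; the #ifdef
// in the wrapper above decides whether they are actually used.
MatrixBatchVectorMultiplyAccumulate(
    input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
    product_scaling_factors, n_batch, accum_scratch_ptr, forget_gate_scratch,
    /*result_stride=*/1, context);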
// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional

@ -461,7 +478,8 @@ inline void LstmStepHybrid(
float* recovered_cell_weights, int8_t* quantized_input_ptr,
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr) {
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to the get the condition.

@ -506,37 +524,41 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, input_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, input_gate_scratch,
/*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, forget_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, forget_gate_scratch,
/*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch,
/*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
product_scaling_factors, n_batch, output_gate_scratch,
/*result_stride=*/1);
MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_ptr, n_cell,
n_input, quantized_input_ptr,
product_scaling_factors, n_batch,
accum_scratch_ptr, output_gate_scratch,
/*result_stride=*/1, context);
}
// For each batch and cell: compute aux_input_weight * aux_input.

@ -555,38 +577,38 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
/*result_stride=*/1);
quantized_aux_input_ptr, product_scaling_factors, n_batch,
accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context);
}
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {

@ -605,38 +627,38 @@ inline void LstmStepHybrid(
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, input_gate_scratch, /*result_stride=*/1, context);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, forget_gate_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);
accum_scratch_ptr, cell_scratch, /*result_stride=*/1, context);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
accum_scratch_ptr, output_gate_scratch, /*result_stride=*/1, context);
}
// For each batch and cell: update input gate.

@ -768,11 +790,12 @@ inline void LstmStepHybrid(
scaling_factors[b] * projection_weights_scale;
}
for (int b = 0; b < n_batch; b++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
/*result_stride=*/1);
/*n_batch=*/1, accum_scratch_ptr,
output_ptr + b * output_batch_leading_dim,
/*result_stride=*/1, context);
}
}
if (params->proj_clip > 0.0) {

@ -1335,7 +1358,8 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, CpuBackendContext* context) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;

@ -1389,60 +1413,59 @@ TfLiteStatus EvalHybrid(
}
float* output_ptr =
GetTensorData<float>(output) + t_rel * output_step + output_offset;
LstmStepHybrid(input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch,
n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state),
GetTensorData<float>(cell_state), output_ptr);
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
GetTensorData<int32_t>(output_scratch_buffer), output_ptr, context);
}
} else {
for (int b = 0; b < n_batch; b++) {

@ -1474,59 +1497,60 @@ TfLiteStatus EvalHybrid(
float* cell_scratch_ptr = cell_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepHybrid(input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr,
output_gate_scratch_ptr,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized),
output_state_ptr, cell_state_ptr, output_ptr);
LstmStepHybrid(
input_ptr, GetTensorData<int8_t>(input_to_input_weights),
GetTensorScale(input_to_input_weights),
GetTensorData<int8_t>(input_to_forget_weights),
GetTensorScale(input_to_forget_weights),
GetTensorData<int8_t>(input_to_cell_weights),
GetTensorScale(input_to_cell_weights),
GetTensorData<int8_t>(input_to_output_weights),
GetTensorScale(input_to_output_weights), aux_input_ptr,
GetTensorData<int8_t>(aux_input_to_input_weights),
GetTensorScale(aux_input_to_input_weights),
GetTensorData<int8_t>(aux_input_to_forget_weights),
GetTensorScale(aux_input_to_forget_weights),
GetTensorData<int8_t>(aux_input_to_cell_weights),
GetTensorScale(aux_input_to_cell_weights),
GetTensorData<int8_t>(aux_input_to_output_weights),
GetTensorScale(aux_input_to_output_weights),
GetTensorData<int8_t>(recurrent_to_input_weights),
GetTensorScale(recurrent_to_input_weights),
GetTensorData<int8_t>(recurrent_to_forget_weights),
GetTensorScale(recurrent_to_forget_weights),
GetTensorData<int8_t>(recurrent_to_cell_weights),
GetTensorScale(recurrent_to_cell_weights),
GetTensorData<int8_t>(recurrent_to_output_weights),
GetTensorScale(recurrent_to_output_weights),
GetTensorData<int8_t>(cell_to_input_weights),
GetTensorScale(cell_to_input_weights),
GetTensorData<int8_t>(cell_to_forget_weights),
GetTensorScale(cell_to_forget_weights),
GetTensorData<int8_t>(cell_to_output_weights),
GetTensorScale(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
GetTensorData<int8_t>(aux_input_quantized),
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
output_ptr, context);
}
}
}

@ -154,7 +154,8 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output);
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, CpuBackendContext* context);
TfLiteStatus EvalInteger(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

@ -90,7 +91,8 @@ enum TemporaryTensor {
kScalingFactors = 4,
kProductScalingFactors = 5,
kRecoveredCellWeights = 6,
kNumTemporaryTensors = 7
kAccumScratch = 7,
kNumTemporaryTensors
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {

@ -497,6 +499,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, recovered_cell_weights,
recovered_cell_weights_size));
}
// Allocate a temporary tensor to store the accumulated int32 values.
node->temporaries->data[kAccumScratch] =
scratch_tensor_index + kAccumScratch;
TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch);
accum_scratch->type = kTfLiteInt32;
accum_scratch->allocation_type = kTfLiteArenaRw;
int accum_scratch_dims[2] = {n_cell, n_batch};
if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
accum_scratch_dims)) {
TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
accum_size->data[0] = n_cell;
accum_size->data[1] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
}
return kTfLiteOk;
}

@ -615,6 +633,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
TfLiteTensor* accum_scratch =
GetTemporary(context, node, /*index=*/kAccumScratch);
return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,

@ -633,7 +653,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, output);
cell_state_quantized, activation_state, cell_state, accum_scratch,
output, CpuBackendContext::GetFromContext(context));
}
default:
context->ReportError(context, "Type %d is not currently supported.",