Use persistent buffer in reference svdf.
PiperOrigin-RevId: 309293414 Change-Id: I33e9b80fc7f634968a7a11ef68a34550c367b92d
This commit is contained in:
parent
9111f46cc5
commit
e9edfeff0b
@ -331,8 +331,20 @@ constexpr int kInputActivationStateTensor = 4;
|
||||
// Output tensor.
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
// Kernel init hook: requests sizeof(OpData) bytes through
// context->AllocatePersistentBuffer and hands the storage back to the caller.
//
// `buffer` and `length` are not read by this function.
//
// Returns the allocated storage on success, or nullptr when the allocator
// reports anything other than kTfLiteOk.
// NOTE(review): the returned pointer is presumably what later shows up as
// node->user_data in Prepare()/Eval() -- confirm against the interpreter.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);

  void* data = nullptr;
  // Compare against kTfLiteOk rather than kTfLiteError so that every non-OK
  // status is handled as a failure instead of returning an unverified pointer.
  const TfLiteStatus status =
      context->AllocatePersistentBuffer(context, sizeof(OpData), &data);
  if (status != kTfLiteOk) {
    return nullptr;
  }
  return data;
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
|
||||
|
||||
// Validate Tensor Inputs (dtype depends on quantization):
|
||||
// [0] = Input, {2, batch_size, input_size}
|
||||
@ -341,7 +353,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// [3] = Bias (optional), {1, num_units}
|
||||
// [4] = Activation State (variable),
|
||||
// {2, batch_size, memory_size * num_filters}
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* weights_feature =
|
||||
GetInput(context, node, kWeightsFeatureTensor);
|
||||
@ -360,8 +371,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int num_units = num_filters / rank;
|
||||
const int memory_size = weights_time->dims->data[1];
|
||||
|
||||
const bool is_full_integer = input->type == kTfLiteInt8;
|
||||
|
||||
// Validate Input Tensor:
|
||||
TF_LITE_ENSURE(context,
|
||||
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
|
||||
@ -385,7 +394,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
|
||||
|
||||
// Validate Optional Bias Input Tensor:
|
||||
if (bias) {
|
||||
if (bias != nullptr) {
|
||||
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
|
||||
}
|
||||
|
||||
@ -395,50 +404,52 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
|
||||
memory_size * num_filters);
|
||||
|
||||
if (is_full_integer) {
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
|
||||
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
|
||||
|
||||
if (bias) {
|
||||
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
|
||||
if (bias != nullptr) {
|
||||
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
|
||||
|
||||
// Validate Scratch Tensors:
|
||||
// [0] = (shared - see float block below for usage)
|
||||
// [1] = Output Temp, int8_t, {2, num_units, batch_size}
|
||||
// TODO(b/132070898): Scratch values are used as stack variables in
|
||||
// EvalIntegerSVDF().
|
||||
|
||||
// Validate output tensor:
|
||||
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
|
||||
|
||||
// Validate Input Tensor dtypes:
|
||||
const auto* input_params =
|
||||
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
|
||||
const auto* weights_feature_params =
|
||||
static_cast<const TfLiteAffineQuantization*>(
|
||||
weights_feature->quantization.params);
|
||||
const auto* state_params = static_cast<const TfLiteAffineQuantization*>(
|
||||
activation_state->quantization.params);
|
||||
const auto* weight_time_params =
|
||||
static_cast<const TfLiteAffineQuantization*>(
|
||||
weights_time->quantization.params);
|
||||
const auto* output_params = static_cast<const TfLiteAffineQuantization*>(
|
||||
output->quantization.params);
|
||||
const double effective_scale_1 = static_cast<double>(
|
||||
input_params->scale->data[0] * weights_feature_params->scale->data[0] /
|
||||
state_params->scale->data[0]);
|
||||
const double effective_scale_2 = static_cast<double>(
|
||||
state_params->scale->data[0] * weight_time_params->scale->data[0] /
|
||||
output_params->scale->data[0]);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
|
||||
&(data->effective_scale_1_b));
|
||||
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
|
||||
&(data->effective_scale_2_b));
|
||||
|
||||
} else {
|
||||
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
|
||||
|
||||
if (bias) {
|
||||
if (bias != nullptr) {
|
||||
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
|
||||
}
|
||||
|
||||
// Validate shared Scratch Tensor:
|
||||
// [0] = Holds dot-product of time-forward calculations in
|
||||
// ApplyTimeWeightsBiasAndActivation():
|
||||
// float/int32, {2, batch_size, num_filters}
|
||||
// TODO(b/132070898): Scratch values are used as stack variables in
|
||||
// EvalIntegerSVDF().
|
||||
|
||||
// Full-float SVDF only uses the one shared scratch tensor (see above for
|
||||
// usage).
|
||||
// TODO(b/132070898): Use input tensor as variable until scratch tensor
|
||||
// allocation has been implemented.
|
||||
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||
}
|
||||
|
||||
@ -458,13 +469,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
const bool is_full_integer = input->type == kTfLiteInt8;
|
||||
|
||||
switch (weights_feature->type) {
|
||||
case kTfLiteFloat32: {
|
||||
// TODO(b/132070898): Use input tensor as variable until scratch tensor
|
||||
// allocation has been implemented.
|
||||
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
|
||||
EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output);
|
||||
return kTfLiteOk;
|
||||
@ -472,43 +478,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
case kTfLiteInt8: {
|
||||
if (is_full_integer) {
|
||||
// TODO(b/132070898): Store these values in ::Prepare() instead of
|
||||
// ::Eval():
|
||||
// Calculate effective scales.
|
||||
OpData op_data;
|
||||
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
input->quantization.params);
|
||||
auto* weights_feature_params =
|
||||
reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
weights_feature->quantization.params);
|
||||
auto* state_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
activation_state->quantization.params);
|
||||
auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
weights_time->quantization.params);
|
||||
auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
output->quantization.params);
|
||||
const double effective_scale_1 =
|
||||
static_cast<double>(input_params->scale->data[0] *
|
||||
weights_feature_params->scale->data[0] /
|
||||
state_params->scale->data[0]);
|
||||
const double effective_scale_2 = static_cast<double>(
|
||||
state_params->scale->data[0] * weight_time_params->scale->data[0] /
|
||||
output_params->scale->data[0]);
|
||||
QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a,
|
||||
&op_data.effective_scale_1_b);
|
||||
QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a,
|
||||
&op_data.effective_scale_2_b);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
|
||||
EvalIntegerSVDF(
|
||||
context, node, input, weights_feature, weights_time, bias, params,
|
||||
activation_state, output, op_data.effective_scale_1_a,
|
||||
op_data.effective_scale_1_b, op_data.effective_scale_2_a,
|
||||
op_data.effective_scale_2_b, input->params.zero_point,
|
||||
output->params.zero_point);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
|
||||
params, activation_state, output,
|
||||
data.effective_scale_1_a, data.effective_scale_1_b,
|
||||
data.effective_scale_2_a, data.effective_scale_2_b,
|
||||
input->params.zero_point, output->params.zero_point);
|
||||
return kTfLiteOk;
|
||||
break;
|
||||
}
|
||||
|
||||
@ -523,7 +501,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace svdf
|
||||
|
||||
TfLiteRegistration* Register_SVDF() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
static TfLiteRegistration r = {/*init=*/svdf::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/svdf::Prepare,
|
||||
/*invoke=*/svdf::Eval,
|
||||
|
Loading…
Reference in New Issue
Block a user