Use persistent buffer in reference svdf.

PiperOrigin-RevId: 309293414
Change-Id: I33e9b80fc7f634968a7a11ef68a34550c367b92d
Author: Advait Jain 2020-04-30 13:54:11 -07:00
Committed-by: TensorFlower Gardener
Parent: 9111f46cc5
Commit: e9edfeff0b
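The change wires an Init hook into the SVDF registration so that quantization constants are computed once in Prepare() and stored in a per-node buffer allocated from the persistent arena, rather than being recomputed on every Eval(). The OpData struct itself is outside this diff; the sketch below infers its layout from the fields touched in Prepare() and Eval(), assuming the standard QuantizeMultiplier(double, int32_t*, int*) helper from quantization_util.h, so field order and exact types are an approximation.

// Sketch only (not part of the diff): per-node state referenced via node->user_data.
struct OpData {
  int32_t effective_scale_1_a;  // Q31 fixed-point multiplier for effective scale 1
  int32_t effective_scale_2_a;  // Q31 fixed-point multiplier for effective scale 2
  int effective_scale_1_b;      // power-of-two shift paired with scale 1
  int effective_scale_2_b;      // power-of-two shift paired with scale 2
};
// Lifecycle: Init() allocates sizeof(OpData) from the persistent arena and the
// framework stores the returned pointer in node->user_data; Prepare() fills the
// fields; Eval() only reads them.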

@@ -331,8 +331,20 @@ constexpr int kInputActivationStateTensor = 4;
// Output tensor.
constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
void* data = nullptr;
if (context->AllocatePersistentBuffer(context, sizeof(OpData), &data) ==
kTfLiteError) {
return nullptr;
}
return data;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->builtin_data != nullptr);
const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
// Validate Tensor Inputs (dtype depends on quantization):
// [0] = Input, {2, batch_size, input_size}
@@ -341,7 +353,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
@@ -360,8 +371,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];
const bool is_full_integer = input->type == kTfLiteInt8;
// Validate Input Tensor:
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
@@ -385,7 +394,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
// Validate Optional Bias Input Tensor:
if (bias) {
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}
@@ -395,50 +404,52 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
memory_size * num_filters);
if (is_full_integer) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
if (bias) {
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
// Validate Scratch Tensors:
// [0] = (shared - see float block below for usage)
// [1] = Output Temp, int8_t, {2, num_units, batch_size}
// TODO(b/132070898): Scratch values are used as stack variables in
// EvalIntegerSVDF().
// Validate output tensor:
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt8);
} else {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
// Validate Input Tensor dtypes:
const auto* input_params =
reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
const auto* weights_feature_params =
static_cast<const TfLiteAffineQuantization*>(
weights_feature->quantization.params);
const auto* state_params = static_cast<const TfLiteAffineQuantization*>(
activation_state->quantization.params);
const auto* weight_time_params =
static_cast<const TfLiteAffineQuantization*>(
weights_time->quantization.params);
const auto* output_params = static_cast<const TfLiteAffineQuantization*>(
output->quantization.params);
const double effective_scale_1 = static_cast<double>(
input_params->scale->data[0] * weights_feature_params->scale->data[0] /
state_params->scale->data[0]);
const double effective_scale_2 = static_cast<double>(
state_params->scale->data[0] * weight_time_params->scale->data[0] /
output_params->scale->data[0]);
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
&(data->effective_scale_1_b));
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
&(data->effective_scale_2_b));
} else {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias) {
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
// Validate shared Scratch Tensor:
// [0] = Holds dot-product of time-forward calculations in
// ApplyTimeWeightsBiasAndActivation():
// float/int32, {2, batch_size, num_filters}
// TODO(b/132070898): Scratch values are used as stack variables in
// EvalIntegerSVDF().
// Full-float SVDF only uses the one shared scratch tensor (see above for
// usage).
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
}
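The values cached just above follow the usual TFLite requantization scheme: each effective scale is a ratio of tensor scales, and QuantizeMultiplier() factors it into a Q31 multiplier and a power-of-two shift so Eval() never has to touch floating point. A small worked example, with made-up scale values:

// Illustrative only: hypothetical tensor scales, not taken from a real model.
const double input_scale = 0.5;
const double weights_feature_scale = 0.01;
const double activation_state_scale = 0.02;
const double effective_scale_1 =
    input_scale * weights_feature_scale / activation_state_scale;  // 0.25

int32_t effective_scale_1_a;  // multiplier
int effective_scale_1_b;      // shift
QuantizeMultiplier(effective_scale_1, &effective_scale_1_a, &effective_scale_1_b);
// frexp(0.25) gives 0.5 * 2^-1, so effective_scale_1_a == 1 << 30 and
// effective_scale_1_b == -1: (2^30 / 2^31) * 2^-1 == 0.25.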
@@ -458,13 +469,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kInputActivationStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const bool is_full_integer = input->type == kTfLiteInt8;
switch (weights_feature->type) {
case kTfLiteFloat32: {
// TODO(b/132070898): Use input tensor as variable until scratch tensor
// allocation has been implemented.
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias,
params, activation_state, output);
return kTfLiteOk;
@@ -472,43 +478,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
case kTfLiteInt8: {
if (is_full_integer) {
// TODO(b/132070898): Store these values in ::Prepare() instead of
// ::Eval():
// Calculate effective scales.
OpData op_data;
auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
input->quantization.params);
auto* weights_feature_params =
reinterpret_cast<TfLiteAffineQuantization*>(
weights_feature->quantization.params);
auto* state_params = reinterpret_cast<TfLiteAffineQuantization*>(
activation_state->quantization.params);
auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
weights_time->quantization.params);
auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
output->quantization.params);
const double effective_scale_1 =
static_cast<double>(input_params->scale->data[0] *
weights_feature_params->scale->data[0] /
state_params->scale->data[0]);
const double effective_scale_2 = static_cast<double>(
state_params->scale->data[0] * weight_time_params->scale->data[0] /
output_params->scale->data[0]);
QuantizeMultiplier(effective_scale_1, &op_data.effective_scale_1_a,
&op_data.effective_scale_1_b);
QuantizeMultiplier(effective_scale_2, &op_data.effective_scale_2_a,
&op_data.effective_scale_2_b);
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
EvalIntegerSVDF(
context, node, input, weights_feature, weights_time, bias, params,
activation_state, output, op_data.effective_scale_1_a,
op_data.effective_scale_1_b, op_data.effective_scale_2_a,
op_data.effective_scale_2_b, input->params.zero_point,
output->params.zero_point);
return kTfLiteOk;
}
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias,
params, activation_state, output,
data.effective_scale_1_a, data.effective_scale_1_b,
data.effective_scale_2_a, data.effective_scale_2_b,
input->params.zero_point, output->params.zero_point);
return kTfLiteOk;
break;
}
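With the scales resolved in Prepare(), the integer path in Eval() only has to apply the cached multiplier/shift pairs. EvalIntegerSVDF() itself is outside this diff; the fragment below is a rough sketch of how such a rescale step is typically written with MultiplyByQuantizedMultiplier() from kernels/internal/common.h (the function name Rescale and its arguments are illustrative, not the kernel's actual code):

// Sketch: requantize an int32 accumulator with a cached (multiplier, shift) pair.
inline int32_t Rescale(int32_t accumulator, int32_t effective_scale_a,
                       int effective_scale_b, int32_t output_zero_point) {
  const int32_t scaled = MultiplyByQuantizedMultiplier(
      accumulator, effective_scale_a, effective_scale_b);
  // The caller then clamps the result to the output type's range (e.g. int8).
  return scaled + output_zero_point;
}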
@@ -523,7 +501,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace svdf
TfLiteRegistration* Register_SVDF() {
static TfLiteRegistration r = {/*init=*/nullptr,
static TfLiteRegistration r = {/*init=*/svdf::Init,
/*free=*/nullptr,
/*prepare=*/svdf::Prepare,
/*invoke=*/svdf::Eval,