Basic hotword model unit test for TFLite Micro.

PiperOrigin-RevId: 267664706
This commit is contained in:
Nick Kreeger 2019-09-06 13:45:59 -07:00 committed by TensorFlower Gardener
parent 354b298bd8
commit 17e730f48b
2 changed files with 20 additions and 5 deletions

View File

@ -359,7 +359,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 6);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
@ -408,7 +411,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// [0] = Holds dot-product of time-forward calculations in
// ApplyTimeWeightsBiasAndActivation():
// float, {2, batch_size, num_filters}
TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// TfLiteTensor* scratch_tensor = GetTemporary(context, node, 0);
TfLiteTensor* scratch_tensor = &context->tensors[node->inputs->data[5]];
TF_LITE_ENSURE_EQ(context, scratch_tensor->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, NumDimensions(scratch_tensor), 2);
TF_LITE_ENSURE_EQ(context, scratch_tensor->dims->data[0], batch_size);
@ -479,7 +486,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Full-float SVDF only uses the one shared scratch tensor (see above for
// usage).
TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
// TODO(kreeger): Use input tensor as variable until scratch tensor
// allocation has been implemented (cl/263032056)
// TF_LITE_ENSURE_EQ(context, node->temporaries->size, 1);
}
// Validate Tensor Output:
@ -504,7 +513,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kWeightsTimeTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch = &context->tensors[node->inputs->data[5]];
TfLiteTensor* activation_state =
&context->tensors[node->inputs->data[kInputActivationStateTensor]];

View File

@ -146,7 +146,10 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
}
// Bias is an optional tensor:
int inputs_array_data[] = {5, 0, 1, 2, kOptionalTensor, 3};
// TODO(kreeger): Use input tensor as variable until scratch tensor allocation
// has been implemented (cl/263032056)
// int inputs_array_data[] = {5, 0, 1, 2, kOptionalTensor, 3};
int inputs_array_data[] = {6, 0, 1, 2, kOptionalTensor, 3, 5};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 4};