Integrate speech commands results over time to give more accurate predictions
PiperOrigin-RevId: 226948751
This commit is contained in:
parent
83cb1f1c5e
commit
6d92ee85a8
@ -176,7 +176,6 @@ cc_library(
|
|||||||
":audio_provider",
|
":audio_provider",
|
||||||
":model_settings",
|
":model_settings",
|
||||||
":preprocessor_reference",
|
":preprocessor_reference",
|
||||||
":timer",
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
],
|
],
|
||||||
@ -191,7 +190,6 @@ tflite_micro_cc_test(
|
|||||||
":audio_provider",
|
":audio_provider",
|
||||||
":feature_provider",
|
":feature_provider",
|
||||||
":model_settings",
|
":model_settings",
|
||||||
":timer",
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||||
@ -221,6 +219,34 @@ tflite_micro_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "recognize_commands",
|
||||||
|
srcs = [
|
||||||
|
"recognize_commands.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"recognize_commands.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":model_settings",
|
||||||
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tflite_micro_cc_test(
|
||||||
|
name = "recognize_commands_test",
|
||||||
|
srcs = [
|
||||||
|
"recognize_commands_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":recognize_commands",
|
||||||
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
|
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "micro_speech",
|
name = "micro_speech",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -232,6 +258,7 @@ cc_binary(
|
|||||||
":features_test_data",
|
":features_test_data",
|
||||||
":model_settings",
|
":model_settings",
|
||||||
":preprocessor_reference",
|
":preprocessor_reference",
|
||||||
|
":recognize_commands",
|
||||||
":timer",
|
":timer",
|
||||||
":tiny_conv_model_data",
|
":tiny_conv_model_data",
|
||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
|
@ -91,7 +91,6 @@ FEATURE_PROVIDER_TEST_SRCS := \
|
|||||||
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
|
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
|
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
|
||||||
ALL_SRCS += $(FEATURE_PROVIDER_TEST_SRCS)
|
ALL_SRCS += $(FEATURE_PROVIDER_TEST_SRCS)
|
||||||
@ -128,6 +127,26 @@ timer_test_bin: $(TIMER_TEST_BINARY).bin
|
|||||||
test_timer: $(TIMER_TEST_BINARY)
|
test_timer: $(TIMER_TEST_BINARY)
|
||||||
$(TEST_SCRIPT) $(TIMER_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
|
$(TEST_SCRIPT) $(TIMER_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
|
||||||
|
|
||||||
|
# Tests the feature provider module.
|
||||||
|
RECOGNIZE_COMMANDS_TEST_SRCS := \
|
||||||
|
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc \
|
||||||
|
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
|
||||||
|
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
|
||||||
|
ALL_SRCS += $(RECOGNIZE_COMMANDS_TEST_SRCS)
|
||||||
|
RECOGNIZE_COMMANDS_TEST_OBJS := $(addprefix $(OBJDIR), \
|
||||||
|
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(RECOGNIZE_COMMANDS_TEST_SRCS))))
|
||||||
|
RECOGNIZE_COMMANDS_TEST_BINARY := $(BINDIR)recognize_commands_test
|
||||||
|
ALL_BINARIES += $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||||
|
$(RECOGNIZE_COMMANDS_TEST_BINARY): $(RECOGNIZE_COMMANDS_TEST_OBJS) $(MICROLITE_LIB_PATH)
|
||||||
|
@mkdir -p $(dir $@)
|
||||||
|
$(CXX) $(CXXFLAGS) $(INCLUDES) \
|
||||||
|
-o $(RECOGNIZE_COMMANDS_TEST_BINARY) $(RECOGNIZE_COMMANDS_TEST_OBJS) \
|
||||||
|
$(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
|
||||||
|
recognize_commands_test: $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||||
|
recognize_commands_test_bin: $(RECOGNIZE_COMMANDS_TEST_BINARY).bin
|
||||||
|
test_recognize_commands: $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||||
|
$(TEST_SCRIPT) $(RECOGNIZE_COMMANDS_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
|
||||||
|
|
||||||
# Builds a standalone speech command recognizer binary.
|
# Builds a standalone speech command recognizer binary.
|
||||||
MICRO_SPEECH_SRCS := \
|
MICRO_SPEECH_SRCS := \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \
|
||||||
@ -138,7 +157,8 @@ tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
|
|||||||
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \
|
tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \
|
||||||
tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
|
tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \
|
||||||
|
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
|
||||||
ALL_SRCS += $(MICRO_SPEECH_SRCS)
|
ALL_SRCS += $(MICRO_SPEECH_SRCS)
|
||||||
MICRO_SPEECH_OBJS := $(addprefix $(OBJDIR), \
|
MICRO_SPEECH_OBJS := $(addprefix $(OBJDIR), \
|
||||||
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_SRCS))))
|
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_SRCS))))
|
||||||
|
@ -18,20 +18,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h"
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// Stores the timestamp for the previous fetch of audio data, so that we can
|
|
||||||
// avoid recalculating all the features from scratch if some earlier timeslices
|
|
||||||
// are still present.
|
|
||||||
int32_t g_last_time_in_ms = 0;
|
|
||||||
// Make sure we don't try to use cached information if this is the first call
|
|
||||||
// into the provider.
|
|
||||||
bool g_is_first_run = true;
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
|
FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
|
||||||
: feature_size_(feature_size), feature_data_(feature_data) {
|
: feature_size_(feature_size),
|
||||||
|
feature_data_(feature_data),
|
||||||
|
is_first_run_(true) {
|
||||||
// Initialize the feature data to default values.
|
// Initialize the feature data to default values.
|
||||||
for (int n = 0; n < feature_size_; ++n) {
|
for (int n = 0; n < feature_size_; ++n) {
|
||||||
feature_data_[n] = 0;
|
feature_data_[n] = 0;
|
||||||
@ -41,24 +32,23 @@ FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
|
|||||||
FeatureProvider::~FeatureProvider() {}
|
FeatureProvider::~FeatureProvider() {}
|
||||||
|
|
||||||
TfLiteStatus FeatureProvider::PopulateFeatureData(
|
TfLiteStatus FeatureProvider::PopulateFeatureData(
|
||||||
tflite::ErrorReporter* error_reporter, int* how_many_new_slices) {
|
tflite::ErrorReporter* error_reporter, int32_t last_time_in_ms,
|
||||||
|
int32_t time_in_ms, int* how_many_new_slices) {
|
||||||
if (feature_size_ != kFeatureElementCount) {
|
if (feature_size_ != kFeatureElementCount) {
|
||||||
error_reporter->Report("Requested feature_data_ size %d doesn't match %d",
|
error_reporter->Report("Requested feature_data_ size %d doesn't match %d",
|
||||||
feature_size_, kFeatureElementCount);
|
feature_size_, kFeatureElementCount);
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int32_t time_in_ms = TimeInMilliseconds();
|
|
||||||
// Quantize the time into steps as long as each window stride, so we can
|
// Quantize the time into steps as long as each window stride, so we can
|
||||||
// figure out which audio data we need to fetch.
|
// figure out which audio data we need to fetch.
|
||||||
const int last_step = (g_last_time_in_ms / kFeatureSliceStrideMs);
|
const int last_step = (last_time_in_ms / kFeatureSliceStrideMs);
|
||||||
const int current_step = (time_in_ms / kFeatureSliceStrideMs);
|
const int current_step = (time_in_ms / kFeatureSliceStrideMs);
|
||||||
g_last_time_in_ms = time_in_ms;
|
|
||||||
|
|
||||||
int slices_needed = current_step - last_step;
|
int slices_needed = current_step - last_step;
|
||||||
// If this is the first call, make sure we don't use any cached information.
|
// If this is the first call, make sure we don't use any cached information.
|
||||||
if (g_is_first_run) {
|
if (is_first_run_) {
|
||||||
g_is_first_run = false;
|
is_first_run_ = false;
|
||||||
slices_needed = kFeatureSliceCount;
|
slices_needed = kFeatureSliceCount;
|
||||||
}
|
}
|
||||||
if (slices_needed > kFeatureSliceCount) {
|
if (slices_needed > kFeatureSliceCount) {
|
||||||
|
@ -38,11 +38,15 @@ class FeatureProvider {
|
|||||||
// Fills the feature data with information from audio inputs, and returns how
|
// Fills the feature data with information from audio inputs, and returns how
|
||||||
// many feature slices were updated.
|
// many feature slices were updated.
|
||||||
TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
|
TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
|
||||||
|
int32_t last_time_in_ms, int32_t time_in_ms,
|
||||||
int* how_many_new_slices);
|
int* how_many_new_slices);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int feature_size_;
|
int feature_size_;
|
||||||
uint8_t* feature_data_;
|
uint8_t* feature_data_;
|
||||||
|
// Make sure we don't try to use cached information if this is the first call
|
||||||
|
// into the provider.
|
||||||
|
bool is_first_run_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
|
||||||
|
@ -30,7 +30,8 @@ TF_LITE_MICRO_TEST(TestFeatureProvider) {
|
|||||||
|
|
||||||
int how_many_new_slices = 0;
|
int how_many_new_slices = 0;
|
||||||
TfLiteStatus populate_status = feature_provider.PopulateFeatureData(
|
TfLiteStatus populate_status = feature_provider.PopulateFeatureData(
|
||||||
error_reporter, &how_many_new_slices);
|
error_reporter, /* last_time_in_ms= */ 0, /* time_in_ms= */ 10000,
|
||||||
|
&how_many_new_slices);
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status);
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status);
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices);
|
TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices);
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
|
||||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||||
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||||
@ -68,16 +70,21 @@ int main(int argc, char* argv[]) {
|
|||||||
FeatureProvider feature_provider(kFeatureElementCount,
|
FeatureProvider feature_provider(kFeatureElementCount,
|
||||||
model_input->data.uint8);
|
model_input->data.uint8);
|
||||||
|
|
||||||
|
RecognizeCommands recognizer(error_reporter);
|
||||||
|
|
||||||
|
int32_t previous_time = 0;
|
||||||
// Keep reading and analysing audio data in an infinite loop.
|
// Keep reading and analysing audio data in an infinite loop.
|
||||||
while (true) {
|
while (true) {
|
||||||
// Fetch the spectrogram for the current time.
|
// Fetch the spectrogram for the current time.
|
||||||
|
const int32_t current_time = TimeInMilliseconds();
|
||||||
int how_many_new_slices = 0;
|
int how_many_new_slices = 0;
|
||||||
TfLiteStatus feature_status = feature_provider.PopulateFeatureData(
|
TfLiteStatus feature_status = feature_provider.PopulateFeatureData(
|
||||||
error_reporter, &how_many_new_slices);
|
error_reporter, previous_time, current_time, &how_many_new_slices);
|
||||||
if (feature_status != kTfLiteOk) {
|
if (feature_status != kTfLiteOk) {
|
||||||
error_reporter->Report("Feature generation failed");
|
error_reporter->Report("Feature generation failed");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
previous_time = current_time;
|
||||||
// If no new audio samples have been received since last time, don't bother
|
// If no new audio samples have been received since last time, don't bother
|
||||||
// running the network model.
|
// running the network model.
|
||||||
if (how_many_new_slices == 0) {
|
if (how_many_new_slices == 0) {
|
||||||
@ -105,7 +112,19 @@ int main(int argc, char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
error_reporter->Report("Heard %s", kCategoryLabels[top_category_index]);
|
const char* found_command = nullptr;
|
||||||
|
uint8_t score = 0;
|
||||||
|
bool is_new_command = false;
|
||||||
|
TfLiteStatus process_status = recognizer.ProcessLatestResults(
|
||||||
|
output, current_time, &found_command, &score, &is_new_command);
|
||||||
|
if (process_status != kTfLiteOk) {
|
||||||
|
error_reporter->Report(
|
||||||
|
"RecognizeCommands::ProcessLatestResults() failed");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (is_new_command) {
|
||||||
|
error_reporter->Report("Heard %s (%d)", found_command, score);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -0,0 +1,139 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter,
|
||||||
|
int32_t average_window_duration_ms,
|
||||||
|
uint8_t detection_threshold,
|
||||||
|
int32_t suppression_ms,
|
||||||
|
int32_t minimum_count)
|
||||||
|
: error_reporter_(error_reporter),
|
||||||
|
average_window_duration_ms_(average_window_duration_ms),
|
||||||
|
detection_threshold_(detection_threshold),
|
||||||
|
suppression_ms_(suppression_ms),
|
||||||
|
minimum_count_(minimum_count),
|
||||||
|
previous_results_(error_reporter) {
|
||||||
|
previous_top_label_ = "_silence_";
|
||||||
|
previous_top_label_time_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus RecognizeCommands::ProcessLatestResults(
|
||||||
|
const TfLiteTensor* latest_results, const int32_t current_time_ms,
|
||||||
|
const char** found_command, uint8_t* score, bool* is_new_command) {
|
||||||
|
if ((latest_results->dims->size != 2) ||
|
||||||
|
(latest_results->dims->data[0] != 1) ||
|
||||||
|
(latest_results->dims->data[1] != kCategoryCount)) {
|
||||||
|
error_reporter_->Report(
|
||||||
|
"The results for recognition should contain %d elements, but there are "
|
||||||
|
"%d in an %d-dimensional shape",
|
||||||
|
kCategoryCount, latest_results->dims->data[1],
|
||||||
|
latest_results->dims->size);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (latest_results->type != kTfLiteUInt8) {
|
||||||
|
error_reporter_->Report(
|
||||||
|
"The results for recognition should be uint8 elements, but are %d",
|
||||||
|
latest_results->type);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((!previous_results_.empty()) &&
|
||||||
|
(current_time_ms < previous_results_.front().time_)) {
|
||||||
|
error_reporter_->Report(
|
||||||
|
"Results must be fed in increasing time order, but received a "
|
||||||
|
"timestamp of %d that was earlier than the previous one of %d",
|
||||||
|
current_time_ms, previous_results_.front().time_);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the latest results to the head of the queue.
|
||||||
|
previous_results_.push_back({current_time_ms, latest_results->data.uint8});
|
||||||
|
|
||||||
|
// Prune any earlier results that are too old for the averaging window.
|
||||||
|
const int64_t time_limit = current_time_ms - average_window_duration_ms_;
|
||||||
|
while ((!previous_results_.empty()) &&
|
||||||
|
previous_results_.front().time_ < time_limit) {
|
||||||
|
previous_results_.pop_front();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are too few results, assume the result will be unreliable and
|
||||||
|
// bail.
|
||||||
|
const int64_t how_many_results = previous_results_.size();
|
||||||
|
const int64_t earliest_time = previous_results_.front().time_;
|
||||||
|
const int64_t samples_duration = current_time_ms - earliest_time;
|
||||||
|
if ((how_many_results < minimum_count_) ||
|
||||||
|
(samples_duration < (average_window_duration_ms_ / 4))) {
|
||||||
|
*found_command = previous_top_label_;
|
||||||
|
*score = 0;
|
||||||
|
*is_new_command = false;
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the average score across all the results in the window.
|
||||||
|
int32_t average_scores[kCategoryCount];
|
||||||
|
for (int offset = 0; offset < previous_results_.size(); ++offset) {
|
||||||
|
PreviousResultsQueue::Result previous_result =
|
||||||
|
previous_results_.from_front(offset);
|
||||||
|
const uint8_t* scores = previous_result.scores_;
|
||||||
|
for (int i = 0; i < kCategoryCount; ++i) {
|
||||||
|
if (offset == 0) {
|
||||||
|
average_scores[i] = scores[i];
|
||||||
|
} else {
|
||||||
|
average_scores[i] += scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < kCategoryCount; ++i) {
|
||||||
|
average_scores[i] /= how_many_results;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the current highest scoring category.
|
||||||
|
int current_top_index = 0;
|
||||||
|
int32_t current_top_score = 0;
|
||||||
|
for (int i = 0; i < kCategoryCount; ++i) {
|
||||||
|
if (average_scores[i] > current_top_score) {
|
||||||
|
current_top_score = average_scores[i];
|
||||||
|
current_top_index = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const char* current_top_label = kCategoryLabels[current_top_index];
|
||||||
|
|
||||||
|
// If we've recently had another label trigger, assume one that occurs too
|
||||||
|
// soon afterwards is a bad result.
|
||||||
|
int64_t time_since_last_top;
|
||||||
|
if ((previous_top_label_ == kCategoryLabels[0]) ||
|
||||||
|
(previous_top_label_time_ == std::numeric_limits<int32_t>::min())) {
|
||||||
|
time_since_last_top = std::numeric_limits<int32_t>::max();
|
||||||
|
} else {
|
||||||
|
time_since_last_top = current_time_ms - previous_top_label_time_;
|
||||||
|
}
|
||||||
|
if ((current_top_score > detection_threshold_) &&
|
||||||
|
(current_top_label != previous_top_label_) &&
|
||||||
|
(time_since_last_top > suppression_ms_)) {
|
||||||
|
previous_top_label_ = current_top_label;
|
||||||
|
previous_top_label_time_ = current_time_ms;
|
||||||
|
*is_new_command = true;
|
||||||
|
} else {
|
||||||
|
*is_new_command = false;
|
||||||
|
}
|
||||||
|
*found_command = current_top_label;
|
||||||
|
*score = current_top_score;
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
@ -0,0 +1,158 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
||||||
|
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||||
|
|
||||||
|
// Partial implementation of std::dequeue, just providing the functionality
|
||||||
|
// that's needed to keep a record of previous neural network results over a
|
||||||
|
// short time period, so they can be averaged together to produce a more
|
||||||
|
// accurate overall prediction. This doesn't use any dynamic memory allocation
|
||||||
|
// so it's a better fit for microcontroller applications, but this does mean
|
||||||
|
// there are hard limits on the number of results it can store.
|
||||||
|
class PreviousResultsQueue {
|
||||||
|
public:
|
||||||
|
PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
|
||||||
|
: error_reporter_(error_reporter), front_index_(0), size_(0) {}
|
||||||
|
|
||||||
|
// Data structure that holds an inference result, and the time when it
|
||||||
|
// was recorded.
|
||||||
|
struct Result {
|
||||||
|
Result() : time_(0), scores_() {}
|
||||||
|
Result(int32_t time, uint8_t* scores) : time_(time) {
|
||||||
|
for (int i = 0; i < kCategoryCount; ++i) {
|
||||||
|
scores_[i] = scores[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int32_t time_;
|
||||||
|
uint8_t scores_[kCategoryCount];
|
||||||
|
};
|
||||||
|
|
||||||
|
int size() { return size_; }
|
||||||
|
bool empty() { return size_ == 0; }
|
||||||
|
Result& front() { return results_[front_index_]; }
|
||||||
|
Result& back() {
|
||||||
|
int back_index = front_index_ + (size_ - 1);
|
||||||
|
if (back_index >= kMaxResults) {
|
||||||
|
back_index -= kMaxResults;
|
||||||
|
}
|
||||||
|
return results_[back_index];
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_back(const Result& entry) {
|
||||||
|
if (size() >= kMaxResults) {
|
||||||
|
error_reporter_->Report(
|
||||||
|
"Couldn't push_back latest result, too many already!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
size_ += 1;
|
||||||
|
back() = entry;
|
||||||
|
}
|
||||||
|
|
||||||
|
Result pop_front() {
|
||||||
|
if (size() <= 0) {
|
||||||
|
error_reporter_->Report("Couldn't pop_front result, none present!");
|
||||||
|
return Result();
|
||||||
|
}
|
||||||
|
Result result = front();
|
||||||
|
front_index_ += 1;
|
||||||
|
if (front_index_ >= kMaxResults) {
|
||||||
|
front_index_ = 0;
|
||||||
|
}
|
||||||
|
size_ -= 1;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most of the functions are duplicates of dequeue containers, but this
|
||||||
|
// is a helper that makes it easy to iterate through the contents of the
|
||||||
|
// queue.
|
||||||
|
Result& from_front(int offset) {
|
||||||
|
if ((offset < 0) || (offset >= size_)) {
|
||||||
|
error_reporter_->Report("Attempt to read beyond the end of the queue!");
|
||||||
|
offset = size_ - 1;
|
||||||
|
}
|
||||||
|
int index = front_index_ + offset;
|
||||||
|
if (index >= kMaxResults) {
|
||||||
|
index -= kMaxResults;
|
||||||
|
}
|
||||||
|
return results_[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
tflite::ErrorReporter* error_reporter_;
|
||||||
|
static constexpr int kMaxResults = 50;
|
||||||
|
Result results_[kMaxResults];
|
||||||
|
|
||||||
|
int front_index_;
|
||||||
|
int size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// This class is designed to apply a very primitive decoding model on top of the
|
||||||
|
// instantaneous results from running an audio recognition model on a single
|
||||||
|
// window of samples. It applies smoothing over time so that noisy individual
|
||||||
|
// label scores are averaged, increasing the confidence that apparent matches
|
||||||
|
// are real.
|
||||||
|
// To use it, you should create a class object with the configuration you
|
||||||
|
// want, and then feed results from running a TensorFlow model into the
|
||||||
|
// processing method. The timestamp for each subsequent call should be
|
||||||
|
// increasing from the previous, since the class is designed to process a stream
|
||||||
|
// of data over time.
|
||||||
|
class RecognizeCommands {
|
||||||
|
public:
|
||||||
|
// labels should be a list of the strings associated with each one-hot score.
|
||||||
|
// The window duration controls the smoothing. Longer durations will give a
|
||||||
|
// higher confidence that the results are correct, but may miss some commands.
|
||||||
|
// The detection threshold has a similar effect, with high values increasing
|
||||||
|
// the precision at the cost of recall. The minimum count controls how many
|
||||||
|
// results need to be in the averaging window before it's seen as a reliable
|
||||||
|
// average. This prevents erroneous results when the averaging window is
|
||||||
|
// initially being populated for example. The suppression argument disables
|
||||||
|
// further recognitions for a set time after one has been triggered, which can
|
||||||
|
// help reduce spurious recognitions.
|
||||||
|
explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
|
||||||
|
int32_t average_window_duration_ms = 1000,
|
||||||
|
uint8_t detection_threshold = 51,
|
||||||
|
int32_t suppression_ms = 500,
|
||||||
|
int32_t minimum_count = 3);
|
||||||
|
|
||||||
|
// Call this with the results of running a model on sample data.
|
||||||
|
TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
|
||||||
|
const int32_t current_time_ms,
|
||||||
|
const char** found_command, uint8_t* score,
|
||||||
|
bool* is_new_command);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Configuration
|
||||||
|
tflite::ErrorReporter* error_reporter_;
|
||||||
|
int32_t average_window_duration_ms_;
|
||||||
|
uint8_t detection_threshold_;
|
||||||
|
int32_t suppression_ms_;
|
||||||
|
int32_t minimum_count_;
|
||||||
|
|
||||||
|
// Working variables
|
||||||
|
PreviousResultsQueue previous_results_;
|
||||||
|
int previous_results_head_;
|
||||||
|
int previous_results_tail_;
|
||||||
|
const char* previous_top_label_;
|
||||||
|
int32_t previous_top_label_time_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
@ -0,0 +1,207 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TESTS_BEGIN
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
PreviousResultsQueue queue(error_reporter);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
|
||||||
|
|
||||||
|
uint8_t scores_a[4] = {0, 0, 0, 1};
|
||||||
|
queue.push_back({0, scores_a});
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.back().time_);
|
||||||
|
|
||||||
|
uint8_t scores_b[4] = {0, 0, 1, 0};
|
||||||
|
queue.push_back({1, scores_b});
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);
|
||||||
|
|
||||||
|
PreviousResultsQueue::Result pop_result = queue.pop_front();
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, pop_result.time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);
|
||||||
|
|
||||||
|
uint8_t scores_c[4] = {0, 1, 0, 0};
|
||||||
|
queue.push_back({2, scores_c});
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(2, queue.back().time_);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(PreviousResultsQueuePushPop) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
PreviousResultsQueue queue(error_reporter);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < 123; ++i) {
|
||||||
|
uint8_t scores[4] = {0, 0, 0, 1};
|
||||||
|
queue.push_back({i, scores});
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(i, queue.front().time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(i, queue.back().time_);
|
||||||
|
|
||||||
|
PreviousResultsQueue::Result pop_result = queue.pop_front();
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(i, pop_result.time_);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(RecognizeCommandsTestBasic) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
RecognizeCommands recognize_commands(error_reporter);
|
||||||
|
|
||||||
|
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{255, 0, 0, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
|
||||||
|
const char* found_command;
|
||||||
|
uint8_t score;
|
||||||
|
bool is_new_command;
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&results, 0, &found_command, &score, &is_new_command));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
|
||||||
|
|
||||||
|
TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
|
||||||
|
bool has_found_new_command = false;
|
||||||
|
const char* new_command;
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
const char* found_command;
|
||||||
|
uint8_t score;
|
||||||
|
bool is_new_command;
|
||||||
|
int32_t current_time_ms = 0 + (i * 100);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&yes_results, current_time_ms, &found_command, &score,
|
||||||
|
&is_new_command));
|
||||||
|
if (is_new_command) {
|
||||||
|
TF_LITE_MICRO_EXPECT(!has_found_new_command);
|
||||||
|
has_found_new_command = true;
|
||||||
|
new_command = found_command;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TF_LITE_MICRO_EXPECT(has_found_new_command);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command));
|
||||||
|
|
||||||
|
TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
has_found_new_command = false;
|
||||||
|
new_command = "";
|
||||||
|
uint8_t score;
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
const char* found_command;
|
||||||
|
bool is_new_command;
|
||||||
|
int32_t current_time_ms = 1000 + (i * 100);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&no_results, current_time_ms, &found_command, &score,
|
||||||
|
&is_new_command));
|
||||||
|
if (is_new_command) {
|
||||||
|
TF_LITE_MICRO_EXPECT(!has_found_new_command);
|
||||||
|
has_found_new_command = true;
|
||||||
|
new_command = found_command;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TF_LITE_MICRO_EXPECT(has_found_new_command);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(231, score);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
|
||||||
|
|
||||||
|
TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 3}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
|
||||||
|
const char* found_command;
|
||||||
|
uint8_t score;
|
||||||
|
bool is_new_command;
|
||||||
|
TF_LITE_MICRO_EXPECT_NE(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&bad_results, 0, &found_command, &score, &is_new_command));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputTimes) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
|
||||||
|
|
||||||
|
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
|
||||||
|
const char* found_command;
|
||||||
|
uint8_t score;
|
||||||
|
bool is_new_command;
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&results, 100, &found_command, &score, &is_new_command));
|
||||||
|
TF_LITE_MICRO_EXPECT_NE(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&results, 0, &found_command, &score, &is_new_command));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(RecognizeCommandsTestTooFewInputs) {
|
||||||
|
tflite::MicroErrorReporter micro_error_reporter;
|
||||||
|
tflite::ErrorReporter* error_reporter = µ_error_reporter;
|
||||||
|
|
||||||
|
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
|
||||||
|
|
||||||
|
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
|
||||||
|
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
|
||||||
|
"input_tensor", 0.0f, 128.0f);
|
||||||
|
|
||||||
|
const char* found_command;
|
||||||
|
uint8_t score;
|
||||||
|
bool is_new_command;
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(
|
||||||
|
kTfLiteOk, recognize_commands.ProcessLatestResults(
|
||||||
|
&results, 100, &found_command, &score, &is_new_command));
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(0, score);
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(false, is_new_command);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TESTS_END
|
@ -48,22 +48,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "test_utils",
|
|
||||||
srcs = [
|
|
||||||
],
|
|
||||||
hdrs = [
|
|
||||||
"test_utils.h",
|
|
||||||
],
|
|
||||||
copts = tflite_copts(),
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
|
||||||
"//tensorflow/lite/core/api",
|
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
|
||||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tflite_micro_cc_test(
|
tflite_micro_cc_test(
|
||||||
name = "depthwise_conv_test",
|
name = "depthwise_conv_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -71,7 +55,6 @@ tflite_micro_cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":all_ops_resolver",
|
":all_ops_resolver",
|
||||||
":test_utils",
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||||
@ -85,7 +68,6 @@ tflite_micro_cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":all_ops_resolver",
|
":all_ops_resolver",
|
||||||
":test_utils",
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||||
@ -99,7 +81,6 @@ tflite_micro_cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":all_ops_resolver",
|
":all_ops_resolver",
|
||||||
":test_utils",
|
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
|
||||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
|
||||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
|
||||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
@ -10,8 +10,10 @@ cc_library(
|
|||||||
name = "micro_test",
|
name = "micro_test",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"micro_test.h",
|
"micro_test.h",
|
||||||
|
"test_utils.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
||||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
||||||
|
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
@ -21,8 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
|
||||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -164,7 +163,20 @@ inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
|
|||||||
return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
|
return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do a simple string comparison for testing purposes, without requiring the
|
||||||
|
// standard C library.
|
||||||
|
inline int TestStrcmp(const char* a, const char* b) {
|
||||||
|
if ((a == nullptr) || (b == nullptr)) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
while ((*a != 0) && (*a == *b)) {
|
||||||
|
a++;
|
||||||
|
b++;
|
||||||
|
}
|
||||||
|
return *(const unsigned char*)a - *(const unsigned char*)b;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
Loading…
Reference in New Issue
Block a user