Integrate speech commands results over time to give more accurate predictions
PiperOrigin-RevId: 226948751
This commit is contained in:
parent
83cb1f1c5e
commit
6d92ee85a8
@ -176,7 +176,6 @@ cc_library(
|
||||
":audio_provider",
|
||||
":model_settings",
|
||||
":preprocessor_reference",
|
||||
":timer",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
],
|
||||
@ -191,7 +190,6 @@ tflite_micro_cc_test(
|
||||
":audio_provider",
|
||||
":feature_provider",
|
||||
":model_settings",
|
||||
":timer",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
@ -221,6 +219,34 @@ tflite_micro_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "recognize_commands",
|
||||
srcs = [
|
||||
"recognize_commands.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"recognize_commands.h",
|
||||
],
|
||||
deps = [
|
||||
":model_settings",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "recognize_commands_test",
|
||||
srcs = [
|
||||
"recognize_commands_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":recognize_commands",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "micro_speech",
|
||||
srcs = [
|
||||
@ -232,6 +258,7 @@ cc_binary(
|
||||
":features_test_data",
|
||||
":model_settings",
|
||||
":preprocessor_reference",
|
||||
":recognize_commands",
|
||||
":timer",
|
||||
":tiny_conv_model_data",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
|
@ -91,7 +91,6 @@ FEATURE_PROVIDER_TEST_SRCS := \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
|
||||
ALL_SRCS += $(FEATURE_PROVIDER_TEST_SRCS)
|
||||
@ -128,6 +127,26 @@ timer_test_bin: $(TIMER_TEST_BINARY).bin
|
||||
test_timer: $(TIMER_TEST_BINARY)
|
||||
$(TEST_SCRIPT) $(TIMER_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
|
||||
|
||||
# Tests the recognize_commands module.
|
||||
RECOGNIZE_COMMANDS_TEST_SRCS := \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
|
||||
ALL_SRCS += $(RECOGNIZE_COMMANDS_TEST_SRCS)
|
||||
RECOGNIZE_COMMANDS_TEST_OBJS := $(addprefix $(OBJDIR), \
|
||||
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(RECOGNIZE_COMMANDS_TEST_SRCS))))
|
||||
RECOGNIZE_COMMANDS_TEST_BINARY := $(BINDIR)recognize_commands_test
|
||||
ALL_BINARIES += $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||
$(RECOGNIZE_COMMANDS_TEST_BINARY): $(RECOGNIZE_COMMANDS_TEST_OBJS) $(MICROLITE_LIB_PATH)
|
||||
@mkdir -p $(dir $@)
|
||||
$(CXX) $(CXXFLAGS) $(INCLUDES) \
|
||||
-o $(RECOGNIZE_COMMANDS_TEST_BINARY) $(RECOGNIZE_COMMANDS_TEST_OBJS) \
|
||||
$(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
|
||||
recognize_commands_test: $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||
recognize_commands_test_bin: $(RECOGNIZE_COMMANDS_TEST_BINARY).bin
|
||||
test_recognize_commands: $(RECOGNIZE_COMMANDS_TEST_BINARY)
|
||||
$(TEST_SCRIPT) $(RECOGNIZE_COMMANDS_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
|
||||
|
||||
# Builds a standalone speech command recognizer binary.
|
||||
MICRO_SPEECH_SRCS := \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \
|
||||
@ -138,7 +157,8 @@ tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \
|
||||
tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
|
||||
ALL_SRCS += $(MICRO_SPEECH_SRCS)
|
||||
MICRO_SPEECH_OBJS := $(addprefix $(OBJDIR), \
|
||||
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_SRCS))))
|
||||
|
@ -18,20 +18,11 @@ limitations under the License.
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
|
||||
|
||||
namespace {
|
||||
// Stores the timestamp for the previous fetch of audio data, so that we can
|
||||
// avoid recalculating all the features from scratch if some earlier timeslices
|
||||
// are still present.
|
||||
int32_t g_last_time_in_ms = 0;
|
||||
// Make sure we don't try to use cached information if this is the first call
|
||||
// into the provider.
|
||||
bool g_is_first_run = true;
|
||||
} // namespace
|
||||
|
||||
FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
|
||||
: feature_size_(feature_size), feature_data_(feature_data) {
|
||||
: feature_size_(feature_size),
|
||||
feature_data_(feature_data),
|
||||
is_first_run_(true) {
|
||||
// Initialize the feature data to default values.
|
||||
for (int n = 0; n < feature_size_; ++n) {
|
||||
feature_data_[n] = 0;
|
||||
@ -41,24 +32,23 @@ FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
|
||||
FeatureProvider::~FeatureProvider() {}
|
||||
|
||||
TfLiteStatus FeatureProvider::PopulateFeatureData(
|
||||
tflite::ErrorReporter* error_reporter, int* how_many_new_slices) {
|
||||
tflite::ErrorReporter* error_reporter, int32_t last_time_in_ms,
|
||||
int32_t time_in_ms, int* how_many_new_slices) {
|
||||
if (feature_size_ != kFeatureElementCount) {
|
||||
error_reporter->Report("Requested feature_data_ size %d doesn't match %d",
|
||||
feature_size_, kFeatureElementCount);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
const int32_t time_in_ms = TimeInMilliseconds();
|
||||
// Quantize the time into steps as long as each window stride, so we can
|
||||
// figure out which audio data we need to fetch.
|
||||
const int last_step = (g_last_time_in_ms / kFeatureSliceStrideMs);
|
||||
const int last_step = (last_time_in_ms / kFeatureSliceStrideMs);
|
||||
const int current_step = (time_in_ms / kFeatureSliceStrideMs);
|
||||
g_last_time_in_ms = time_in_ms;
|
||||
|
||||
int slices_needed = current_step - last_step;
|
||||
// If this is the first call, make sure we don't use any cached information.
|
||||
if (g_is_first_run) {
|
||||
g_is_first_run = false;
|
||||
if (is_first_run_) {
|
||||
is_first_run_ = false;
|
||||
slices_needed = kFeatureSliceCount;
|
||||
}
|
||||
if (slices_needed > kFeatureSliceCount) {
|
||||
|
@ -38,11 +38,15 @@ class FeatureProvider {
|
||||
// Fills the feature data with information from audio inputs, and returns how
|
||||
// many feature slices were updated.
|
||||
TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
|
||||
int32_t last_time_in_ms, int32_t time_in_ms,
|
||||
int* how_many_new_slices);
|
||||
|
||||
private:
|
||||
int feature_size_;
|
||||
uint8_t* feature_data_;
|
||||
// Make sure we don't try to use cached information if this is the first call
|
||||
// into the provider.
|
||||
bool is_first_run_;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
|
||||
|
@ -30,7 +30,8 @@ TF_LITE_MICRO_TEST(TestFeatureProvider) {
|
||||
|
||||
int how_many_new_slices = 0;
|
||||
TfLiteStatus populate_status = feature_provider.PopulateFeatureData(
|
||||
error_reporter, &how_many_new_slices);
|
||||
error_reporter, /* last_time_in_ms= */ 0, /* time_in_ms= */ 10000,
|
||||
&how_many_new_slices);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices);
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||
@ -68,16 +70,21 @@ int main(int argc, char* argv[]) {
|
||||
FeatureProvider feature_provider(kFeatureElementCount,
|
||||
model_input->data.uint8);
|
||||
|
||||
RecognizeCommands recognizer(error_reporter);
|
||||
|
||||
int32_t previous_time = 0;
|
||||
// Keep reading and analysing audio data in an infinite loop.
|
||||
while (true) {
|
||||
// Fetch the spectrogram for the current time.
|
||||
const int32_t current_time = TimeInMilliseconds();
|
||||
int how_many_new_slices = 0;
|
||||
TfLiteStatus feature_status = feature_provider.PopulateFeatureData(
|
||||
error_reporter, &how_many_new_slices);
|
||||
error_reporter, previous_time, current_time, &how_many_new_slices);
|
||||
if (feature_status != kTfLiteOk) {
|
||||
error_reporter->Report("Feature generation failed");
|
||||
return 1;
|
||||
}
|
||||
previous_time = current_time;
|
||||
// If no new audio samples have been received since last time, don't bother
|
||||
// running the network model.
|
||||
if (how_many_new_slices == 0) {
|
||||
@ -105,7 +112,19 @@ int main(int argc, char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
error_reporter->Report("Heard %s", kCategoryLabels[top_category_index]);
|
||||
const char* found_command = nullptr;
|
||||
uint8_t score = 0;
|
||||
bool is_new_command = false;
|
||||
TfLiteStatus process_status = recognizer.ProcessLatestResults(
|
||||
output, current_time, &found_command, &score, &is_new_command);
|
||||
if (process_status != kTfLiteOk) {
|
||||
error_reporter->Report(
|
||||
"RecognizeCommands::ProcessLatestResults() failed");
|
||||
return 1;
|
||||
}
|
||||
if (is_new_command) {
|
||||
error_reporter->Report("Heard %s (%d)", found_command, score);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
@ -0,0 +1,139 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
// Sets up the recognizer with the given smoothing configuration; see the
// header for the meaning of each parameter. All members are initialized in
// the constructor initializer list rather than the body.
RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter,
                                     int32_t average_window_duration_ms,
                                     uint8_t detection_threshold,
                                     int32_t suppression_ms,
                                     int32_t minimum_count)
    : error_reporter_(error_reporter),
      average_window_duration_ms_(average_window_duration_ms),
      detection_threshold_(detection_threshold),
      suppression_ms_(suppression_ms),
      minimum_count_(minimum_count),
      previous_results_(error_reporter),
      // NOTE(review): ProcessLatestResults compares this pointer against
      // kCategoryLabels[0], which presumably also holds "_silence_". Distinct
      // string literals aren't guaranteed to share storage, so that pointer
      // comparison looks fragile -- confirm against model_settings.cc.
      previous_top_label_("_silence_"),
      previous_top_label_time_(0) {}
||||
|
||||
// Feeds one frame of model output into the averaging window and decides
// whether a command has been recognized.
//
// latest_results must be a [1, kCategoryCount] uint8 tensor, and successive
// calls must use non-decreasing current_time_ms timestamps. On success,
// *found_command points at the top-scoring label, *score holds its averaged
// score, and *is_new_command is true only when a fresh detection fired.
// Returns kTfLiteError (with a report) on malformed input.
TfLiteStatus RecognizeCommands::ProcessLatestResults(
    const TfLiteTensor* latest_results, const int32_t current_time_ms,
    const char** found_command, uint8_t* score, bool* is_new_command) {
  // Validate the tensor shape before touching its data.
  if ((latest_results->dims->size != 2) ||
      (latest_results->dims->data[0] != 1) ||
      (latest_results->dims->data[1] != kCategoryCount)) {
    error_reporter_->Report(
        "The results for recognition should contain %d elements, but there are "
        "%d in an %d-dimensional shape",
        kCategoryCount, latest_results->dims->data[1],
        latest_results->dims->size);
    return kTfLiteError;
  }

  if (latest_results->type != kTfLiteUInt8) {
    error_reporter_->Report(
        "The results for recognition should be uint8 elements, but are %d",
        latest_results->type);
    return kTfLiteError;
  }

  // The smoothing logic relies on the queue being time-ordered, so reject
  // timestamps that move backwards.
  if ((!previous_results_.empty()) &&
      (current_time_ms < previous_results_.front().time_)) {
    error_reporter_->Report(
        "Results must be fed in increasing time order, but received a "
        "timestamp of %d that was earlier than the previous one of %d",
        current_time_ms, previous_results_.front().time_);
    return kTfLiteError;
  }

  // Add the latest results to the end of the queue.
  previous_results_.push_back({current_time_ms, latest_results->data.uint8});

  // Prune any earlier results that are too old for the averaging window.
  const int64_t time_limit = current_time_ms - average_window_duration_ms_;
  while ((!previous_results_.empty()) &&
         previous_results_.front().time_ < time_limit) {
    previous_results_.pop_front();
  }

  // If there are too few results, assume the result will be unreliable and
  // bail.
  const int64_t how_many_results = previous_results_.size();
  const int64_t earliest_time = previous_results_.front().time_;
  const int64_t samples_duration = current_time_ms - earliest_time;
  if ((how_many_results < minimum_count_) ||
      (samples_duration < (average_window_duration_ms_ / 4))) {
    *found_command = previous_top_label_;
    *score = 0;
    *is_new_command = false;
    return kTfLiteOk;
  }

  // Calculate the average score across all the results in the window.
  int32_t average_scores[kCategoryCount];
  for (int offset = 0; offset < previous_results_.size(); ++offset) {
    PreviousResultsQueue::Result previous_result =
        previous_results_.from_front(offset);
    const uint8_t* scores = previous_result.scores_;
    for (int i = 0; i < kCategoryCount; ++i) {
      if (offset == 0) {
        average_scores[i] = scores[i];
      } else {
        average_scores[i] += scores[i];
      }
    }
  }
  for (int i = 0; i < kCategoryCount; ++i) {
    average_scores[i] /= how_many_results;
  }

  // Find the current highest scoring category.
  int current_top_index = 0;
  int32_t current_top_score = 0;
  for (int i = 0; i < kCategoryCount; ++i) {
    if (average_scores[i] > current_top_score) {
      current_top_score = average_scores[i];
      current_top_index = i;
    }
  }
  const char* current_top_label = kCategoryLabels[current_top_index];

  // If we've recently had another label trigger, assume one that occurs too
  // soon afterwards is a bad result.
  int64_t time_since_last_top;
  if ((previous_top_label_ == kCategoryLabels[0]) ||
      (previous_top_label_time_ == std::numeric_limits<int32_t>::min())) {
    time_since_last_top = std::numeric_limits<int32_t>::max();
  } else {
    time_since_last_top = current_time_ms - previous_top_label_time_;
  }
  // A detection fires only when the averaged score clears the threshold, the
  // label has changed, and the suppression window has elapsed.
  if ((current_top_score > detection_threshold_) &&
      (current_top_label != previous_top_label_) &&
      (time_since_last_top > suppression_ms_)) {
    previous_top_label_ = current_top_label;
    previous_top_label_time_ = current_time_ms;
    *is_new_command = true;
  } else {
    *is_new_command = false;
  }
  *found_command = current_top_label;
  *score = current_top_score;

  return kTfLiteOk;
}
|
@ -0,0 +1,158 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
|
||||
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||
|
||||
// Partial implementation of std::dequeue, just providing the functionality
|
||||
// that's needed to keep a record of previous neural network results over a
|
||||
// short time period, so they can be averaged together to produce a more
|
||||
// accurate overall prediction. This doesn't use any dynamic memory allocation
|
||||
// so it's a better fit for microcontroller applications, but this does mean
|
||||
// there are hard limits on the number of results it can store.
|
||||
class PreviousResultsQueue {
|
||||
public:
|
||||
PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
|
||||
: error_reporter_(error_reporter), front_index_(0), size_(0) {}
|
||||
|
||||
// Data structure that holds an inference result, and the time when it
|
||||
// was recorded.
|
||||
struct Result {
|
||||
Result() : time_(0), scores_() {}
|
||||
Result(int32_t time, uint8_t* scores) : time_(time) {
|
||||
for (int i = 0; i < kCategoryCount; ++i) {
|
||||
scores_[i] = scores[i];
|
||||
}
|
||||
}
|
||||
int32_t time_;
|
||||
uint8_t scores_[kCategoryCount];
|
||||
};
|
||||
|
||||
int size() { return size_; }
|
||||
bool empty() { return size_ == 0; }
|
||||
Result& front() { return results_[front_index_]; }
|
||||
Result& back() {
|
||||
int back_index = front_index_ + (size_ - 1);
|
||||
if (back_index >= kMaxResults) {
|
||||
back_index -= kMaxResults;
|
||||
}
|
||||
return results_[back_index];
|
||||
}
|
||||
|
||||
void push_back(const Result& entry) {
|
||||
if (size() >= kMaxResults) {
|
||||
error_reporter_->Report(
|
||||
"Couldn't push_back latest result, too many already!");
|
||||
return;
|
||||
}
|
||||
size_ += 1;
|
||||
back() = entry;
|
||||
}
|
||||
|
||||
Result pop_front() {
|
||||
if (size() <= 0) {
|
||||
error_reporter_->Report("Couldn't pop_front result, none present!");
|
||||
return Result();
|
||||
}
|
||||
Result result = front();
|
||||
front_index_ += 1;
|
||||
if (front_index_ >= kMaxResults) {
|
||||
front_index_ = 0;
|
||||
}
|
||||
size_ -= 1;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Most of the functions are duplicates of dequeue containers, but this
|
||||
// is a helper that makes it easy to iterate through the contents of the
|
||||
// queue.
|
||||
Result& from_front(int offset) {
|
||||
if ((offset < 0) || (offset >= size_)) {
|
||||
error_reporter_->Report("Attempt to read beyond the end of the queue!");
|
||||
offset = size_ - 1;
|
||||
}
|
||||
int index = front_index_ + offset;
|
||||
if (index >= kMaxResults) {
|
||||
index -= kMaxResults;
|
||||
}
|
||||
return results_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
tflite::ErrorReporter* error_reporter_;
|
||||
static constexpr int kMaxResults = 50;
|
||||
Result results_[kMaxResults];
|
||||
|
||||
int front_index_;
|
||||
int size_;
|
||||
};
|
||||
|
||||
// This class is designed to apply a very primitive decoding model on top of the
|
||||
// instantaneous results from running an audio recognition model on a single
|
||||
// window of samples. It applies smoothing over time so that noisy individual
|
||||
// label scores are averaged, increasing the confidence that apparent matches
|
||||
// are real.
|
||||
// To use it, you should create a class object with the configuration you
|
||||
// want, and then feed results from running a TensorFlow model into the
|
||||
// processing method. The timestamp for each subsequent call should be
|
||||
// increasing from the previous, since the class is designed to process a stream
|
||||
// of data over time.
|
||||
class RecognizeCommands {
|
||||
public:
|
||||
// labels should be a list of the strings associated with each one-hot score.
|
||||
// The window duration controls the smoothing. Longer durations will give a
|
||||
// higher confidence that the results are correct, but may miss some commands.
|
||||
// The detection threshold has a similar effect, with high values increasing
|
||||
// the precision at the cost of recall. The minimum count controls how many
|
||||
// results need to be in the averaging window before it's seen as a reliable
|
||||
// average. This prevents erroneous results when the averaging window is
|
||||
// initially being populated for example. The suppression argument disables
|
||||
// further recognitions for a set time after one has been triggered, which can
|
||||
// help reduce spurious recognitions.
|
||||
explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
|
||||
int32_t average_window_duration_ms = 1000,
|
||||
uint8_t detection_threshold = 51,
|
||||
int32_t suppression_ms = 500,
|
||||
int32_t minimum_count = 3);
|
||||
|
||||
// Call this with the results of running a model on sample data.
|
||||
TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
|
||||
const int32_t current_time_ms,
|
||||
const char** found_command, uint8_t* score,
|
||||
bool* is_new_command);
|
||||
|
||||
private:
|
||||
// Configuration
|
||||
tflite::ErrorReporter* error_reporter_;
|
||||
int32_t average_window_duration_ms_;
|
||||
uint8_t detection_threshold_;
|
||||
int32_t suppression_ms_;
|
||||
int32_t minimum_count_;
|
||||
|
||||
// Working variables
|
||||
PreviousResultsQueue previous_results_;
|
||||
int previous_results_head_;
|
||||
int previous_results_tail_;
|
||||
const char* previous_top_label_;
|
||||
int32_t previous_top_label_time_;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
|
@ -0,0 +1,207 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
|
||||
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  // A freshly-constructed queue starts out empty.
  PreviousResultsQueue queue(error_reporter);
  TF_LITE_MICRO_EXPECT_EQ(0, queue.size());

  // Single push: the entry is both front and back.
  uint8_t first_scores[4] = {0, 0, 0, 1};
  queue.push_back({0, first_scores});
  TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
  TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
  TF_LITE_MICRO_EXPECT_EQ(0, queue.back().time_);

  // Second push: front stays the oldest entry, back becomes the newest.
  uint8_t second_scores[4] = {0, 0, 1, 0};
  queue.push_back({1, second_scores});
  TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
  TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
  TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);

  // Popping removes the oldest entry and shifts the front forward.
  PreviousResultsQueue::Result popped = queue.pop_front();
  TF_LITE_MICRO_EXPECT_EQ(0, popped.time_);
  TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
  TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
  TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);

  // Push after a pop keeps front/back consistent.
  uint8_t third_scores[4] = {0, 1, 0, 0};
  queue.push_back({2, third_scores});
  TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
  TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
  TF_LITE_MICRO_EXPECT_EQ(2, queue.back().time_);
}
|
||||
|
||||
TF_LITE_MICRO_TEST(PreviousResultsQueuePushPop) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  PreviousResultsQueue queue(error_reporter);
  TF_LITE_MICRO_EXPECT_EQ(0, queue.size());

  // Push and immediately pop many more entries than the queue's capacity,
  // which exercises the circular-buffer index wrapping.
  for (int step = 0; step < 123; ++step) {
    uint8_t step_scores[4] = {0, 0, 0, 1};
    queue.push_back({step, step_scores});
    TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
    TF_LITE_MICRO_EXPECT_EQ(step, queue.front().time_);
    TF_LITE_MICRO_EXPECT_EQ(step, queue.back().time_);

    PreviousResultsQueue::Result popped = queue.pop_front();
    TF_LITE_MICRO_EXPECT_EQ(step, popped.time_);
    TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
  }
}
|
||||
|
||||
TF_LITE_MICRO_TEST(RecognizeCommandsTestBasic) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  RecognizeCommands recognize_commands(error_reporter);

  // A single well-formed result (all score on category 0) should be accepted
  // without error.
  TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
      {255, 0, 0, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
      "input_tensor", 0.0f, 128.0f);

  const char* command;
  uint8_t top_score;
  bool triggered;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, recognize_commands.ProcessLatestResults(
                     &results, 0, &command, &top_score, &triggered));
}
|
||||
|
||||
TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  RecognizeCommands recognize_commands(error_reporter, 1000, 51);

  // A stream of results strongly scoring the "yes" category.
  TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor(
      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
      "input_tensor", 0.0f, 128.0f);

  bool has_found_new_command = false;
  // Initialized to an empty string so the TestStrcmp below is safe even if no
  // command is ever reported: the EXPECT macros record failures but don't
  // abort, so an uninitialized pointer would otherwise be read on that path.
  const char* new_command = "";
  for (int i = 0; i < 10; ++i) {
    const char* found_command;
    uint8_t score;
    bool is_new_command;
    int32_t current_time_ms = 0 + (i * 100);
    TF_LITE_MICRO_EXPECT_EQ(
        kTfLiteOk, recognize_commands.ProcessLatestResults(
                       &yes_results, current_time_ms, &found_command, &score,
                       &is_new_command));
    if (is_new_command) {
      // Only a single fresh detection should fire for a sustained "yes".
      TF_LITE_MICRO_EXPECT(!has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  TF_LITE_MICRO_EXPECT(has_found_new_command);
  TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command));

  // Follow with a stream strongly scoring "no"; a second detection should
  // fire once the averaging window is dominated by the new label.
  TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor(
      {0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
      "input_tensor", 0.0f, 128.0f);
  has_found_new_command = false;
  new_command = "";
  // Initialized in case no iteration succeeds and writes a score.
  uint8_t score = 0;
  for (int i = 0; i < 10; ++i) {
    const char* found_command;
    bool is_new_command;
    int32_t current_time_ms = 1000 + (i * 100);
    TF_LITE_MICRO_EXPECT_EQ(
        kTfLiteOk, recognize_commands.ProcessLatestResults(
                       &no_results, current_time_ms, &found_command, &score,
                       &is_new_command));
    if (is_new_command) {
      TF_LITE_MICRO_EXPECT(!has_found_new_command);
      has_found_new_command = true;
      new_command = found_command;
    }
  }
  TF_LITE_MICRO_EXPECT(has_found_new_command);
  TF_LITE_MICRO_EXPECT_EQ(231, score);
  TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command));
}
|
||||
|
||||
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  RecognizeCommands recognize_commands(error_reporter, 1000, 51);

  // Deliberately build a three-element tensor, one short of kCategoryCount.
  TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor(
      {0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 3}),
      "input_tensor", 0.0f, 128.0f);

  const char* command;
  uint8_t top_score;
  bool triggered;
  // A wrongly-shaped input must be rejected, not processed.
  TF_LITE_MICRO_EXPECT_NE(
      kTfLiteOk, recognize_commands.ProcessLatestResults(
                     &bad_results, 0, &command, &top_score, &triggered));
}
|
||||
|
||||
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputTimes) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  RecognizeCommands recognize_commands(error_reporter, 1000, 51);

  TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
      "input_tensor", 0.0f, 128.0f);

  const char* command;
  uint8_t top_score;
  bool triggered;
  // A first result at t=100 is fine...
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, recognize_commands.ProcessLatestResults(
                     &results, 100, &command, &top_score, &triggered));
  // ...but a later call with an earlier timestamp must be rejected.
  TF_LITE_MICRO_EXPECT_NE(
      kTfLiteOk, recognize_commands.ProcessLatestResults(
                     &results, 0, &command, &top_score, &triggered));
}
|
||||
|
||||
TF_LITE_MICRO_TEST(RecognizeCommandsTestTooFewInputs) {
  tflite::MicroErrorReporter micro_error_reporter;
  tflite::ErrorReporter* error_reporter = &micro_error_reporter;

  RecognizeCommands recognize_commands(error_reporter, 1000, 51);

  TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
      {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
      "input_tensor", 0.0f, 128.0f);

  const char* command;
  uint8_t top_score;
  bool triggered;
  // With only a single result in the averaging window, the call should
  // succeed but report no detection and a zero score.
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, recognize_commands.ProcessLatestResults(
                     &results, 100, &command, &top_score, &triggered));
  TF_LITE_MICRO_EXPECT_EQ(0, top_score);
  TF_LITE_MICRO_EXPECT_EQ(false, triggered);
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
@ -48,22 +48,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_utils",
|
||||
srcs = [
|
||||
],
|
||||
hdrs = [
|
||||
"test_utils.h",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "depthwise_conv_test",
|
||||
srcs = [
|
||||
@ -71,7 +55,6 @@ tflite_micro_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
":test_utils",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
@ -85,7 +68,6 @@ tflite_micro_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
":test_utils",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
@ -99,7 +81,6 @@ tflite_micro_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
":test_utils",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
@ -10,8 +10,10 @@ cc_library(
|
||||
name = "micro_test",
|
||||
hdrs = [
|
||||
"micro_test.h",
|
||||
"test_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
],
|
||||
)
|
||||
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
||||
|
||||
#include <cstdarg>
|
||||
#include <initializer_list>
|
||||
@ -21,8 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
|
||||
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -164,7 +163,20 @@ inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
|
||||
return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
|
||||
}
|
||||
|
||||
// Do a simple string comparison for testing purposes, without requiring the
|
||||
// standard C library.
|
||||
inline int TestStrcmp(const char* a, const char* b) {
|
||||
if ((a == nullptr) || (b == nullptr)) {
|
||||
return -1;
|
||||
}
|
||||
while ((*a != 0) && (*a == *b)) {
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
return *(const unsigned char*)a - *(const unsigned char*)b;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
Loading…
Reference in New Issue
Block a user