diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD index 799b2e5a5dd..70eeac14585 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD @@ -176,7 +176,6 @@ cc_library( ":audio_provider", ":model_settings", ":preprocessor_reference", - ":timer", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", ], @@ -191,7 +190,6 @@ tflite_micro_cc_test( ":audio_provider", ":feature_provider", ":model_settings", - ":timer", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -221,6 +219,34 @@ tflite_micro_cc_test( ], ) +cc_library( + name = "recognize_commands", + srcs = [ + "recognize_commands.cc", + ], + hdrs = [ + "recognize_commands.h", + ], + deps = [ + ":model_settings", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + ], +) + +tflite_micro_cc_test( + name = "recognize_commands_test", + srcs = [ + "recognize_commands_test.cc", + ], + deps = [ + ":recognize_commands", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/experimental/micro:micro_framework", + "//tensorflow/lite/experimental/micro/testing:micro_test", + ], +) + cc_binary( name = "micro_speech", srcs = [ @@ -232,6 +258,7 @@ cc_binary( ":features_test_data", ":model_settings", ":preprocessor_reference", + ":recognize_commands", ":timer", ":tiny_conv_model_data", "//tensorflow/lite:schema_fbs_version", diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc index 0e42329cade..cce6ea8402a 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/Makefile.inc @@ -91,7 +91,6 @@ FEATURE_PROVIDER_TEST_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc ALL_SRCS += $(FEATURE_PROVIDER_TEST_SRCS) @@ -128,6 +127,26 @@ timer_test_bin: $(TIMER_TEST_BINARY).bin test_timer: $(TIMER_TEST_BINARY) $(TEST_SCRIPT) $(TIMER_TEST_BINARY) '~~~ALL TESTS PASSED~~~' +# Tests the feature provider module. +RECOGNIZE_COMMANDS_TEST_SRCS := \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc +ALL_SRCS += $(RECOGNIZE_COMMANDS_TEST_SRCS) +RECOGNIZE_COMMANDS_TEST_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(RECOGNIZE_COMMANDS_TEST_SRCS)))) +RECOGNIZE_COMMANDS_TEST_BINARY := $(BINDIR)recognize_commands_test +ALL_BINARIES += $(RECOGNIZE_COMMANDS_TEST_BINARY) +$(RECOGNIZE_COMMANDS_TEST_BINARY): $(RECOGNIZE_COMMANDS_TEST_OBJS) $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(RECOGNIZE_COMMANDS_TEST_BINARY) $(RECOGNIZE_COMMANDS_TEST_OBJS) \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) +recognize_commands_test: $(RECOGNIZE_COMMANDS_TEST_BINARY) +recognize_commands_test_bin: $(RECOGNIZE_COMMANDS_TEST_BINARY).bin +test_recognize_commands: $(RECOGNIZE_COMMANDS_TEST_BINARY) + $(TEST_SCRIPT) $(RECOGNIZE_COMMANDS_TEST_BINARY) '~~~ALL TESTS PASSED~~~' + # Builds a standalone speech command recognizer binary. MICRO_SPEECH_SRCS := \ tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \ @@ -138,7 +157,8 @@ tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \ tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \ -tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc +tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \ +tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc ALL_SRCS += $(MICRO_SPEECH_SRCS) MICRO_SPEECH_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_SRCS)))) diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc index c4c52ac0ff3..7f9ece41dd3 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc @@ -18,20 +18,11 @@ limitations under the License. #include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h" -#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h" - -namespace { -// Stores the timestamp for the previous fetch of audio data, so that we can -// avoid recalculating all the features from scratch if some earlier timeslices -// are still present. -int32_t g_last_time_in_ms = 0; -// Make sure we don't try to use cached information if this is the first call -// into the provider. -bool g_is_first_run = true; -} // namespace FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data) - : feature_size_(feature_size), feature_data_(feature_data) { + : feature_size_(feature_size), + feature_data_(feature_data), + is_first_run_(true) { // Initialize the feature data to default values. for (int n = 0; n < feature_size_; ++n) { feature_data_[n] = 0; @@ -41,24 +32,23 @@ FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data) FeatureProvider::~FeatureProvider() {} TfLiteStatus FeatureProvider::PopulateFeatureData( - tflite::ErrorReporter* error_reporter, int* how_many_new_slices) { + tflite::ErrorReporter* error_reporter, int32_t last_time_in_ms, + int32_t time_in_ms, int* how_many_new_slices) { if (feature_size_ != kFeatureElementCount) { error_reporter->Report("Requested feature_data_ size %d doesn't match %d", feature_size_, kFeatureElementCount); return kTfLiteError; } - const int32_t time_in_ms = TimeInMilliseconds(); // Quantize the time into steps as long as each window stride, so we can // figure out which audio data we need to fetch. - const int last_step = (g_last_time_in_ms / kFeatureSliceStrideMs); + const int last_step = (last_time_in_ms / kFeatureSliceStrideMs); const int current_step = (time_in_ms / kFeatureSliceStrideMs); - g_last_time_in_ms = time_in_ms; int slices_needed = current_step - last_step; // If this is the first call, make sure we don't use any cached information. - if (g_is_first_run) { - g_is_first_run = false; + if (is_first_run_) { + is_first_run_ = false; slices_needed = kFeatureSliceCount; } if (slices_needed > kFeatureSliceCount) { diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h index a86c56ebf05..ee3a480e947 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h @@ -38,11 +38,15 @@ class FeatureProvider { // Fills the feature data with information from audio inputs, and returns how // many feature slices were updated. TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter, + int32_t last_time_in_ms, int32_t time_in_ms, int* how_many_new_slices); private: int feature_size_; uint8_t* feature_data_; + // Make sure we don't try to use cached information if this is the first call + // into the provider. + bool is_first_run_; }; #endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc index 1e52aec8d27..556cbfe799b 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc @@ -30,7 +30,8 @@ TF_LITE_MICRO_TEST(TestFeatureProvider) { int how_many_new_slices = 0; TfLiteStatus populate_status = feature_provider.PopulateFeatureData( - error_reporter, &how_many_new_slices); + error_reporter, /* last_time_in_ms= */ 0, /* time_in_ms= */ 10000, + &how_many_new_slices); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status); TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices); } diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc index 1890c25cf2b..515f82fcbc4 100644 --- a/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h" #include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" #include "tensorflow/lite/experimental/micro/micro_error_reporter.h" @@ -68,16 +70,21 @@ int main(int argc, char* argv[]) { FeatureProvider feature_provider(kFeatureElementCount, model_input->data.uint8); + RecognizeCommands recognizer(error_reporter); + + int32_t previous_time = 0; // Keep reading and analysing audio data in an infinite loop. while (true) { // Fetch the spectrogram for the current time. + const int32_t current_time = TimeInMilliseconds(); int how_many_new_slices = 0; TfLiteStatus feature_status = feature_provider.PopulateFeatureData( - error_reporter, &how_many_new_slices); + error_reporter, previous_time, current_time, &how_many_new_slices); if (feature_status != kTfLiteOk) { error_reporter->Report("Feature generation failed"); return 1; } + previous_time = current_time; // If no new audio samples have been received since last time, don't bother // running the network model. if (how_many_new_slices == 0) { @@ -105,7 +112,19 @@ int main(int argc, char* argv[]) { } } - error_reporter->Report("Heard %s", kCategoryLabels[top_category_index]); + const char* found_command = nullptr; + uint8_t score = 0; + bool is_new_command = false; + TfLiteStatus process_status = recognizer.ProcessLatestResults( + output, current_time, &found_command, &score, &is_new_command); + if (process_status != kTfLiteOk) { + error_reporter->Report( + "RecognizeCommands::ProcessLatestResults() failed"); + return 1; + } + if (is_new_command) { + error_reporter->Report("Heard %s (%d)", found_command, score); + } } return 0; diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc new file mode 100644 index 00000000000..9366dc71e0d --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h" + +#include + +RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter, + int32_t average_window_duration_ms, + uint8_t detection_threshold, + int32_t suppression_ms, + int32_t minimum_count) + : error_reporter_(error_reporter), + average_window_duration_ms_(average_window_duration_ms), + detection_threshold_(detection_threshold), + suppression_ms_(suppression_ms), + minimum_count_(minimum_count), + previous_results_(error_reporter) { + previous_top_label_ = "_silence_"; + previous_top_label_time_ = 0; +} + +TfLiteStatus RecognizeCommands::ProcessLatestResults( + const TfLiteTensor* latest_results, const int32_t current_time_ms, + const char** found_command, uint8_t* score, bool* is_new_command) { + if ((latest_results->dims->size != 2) || + (latest_results->dims->data[0] != 1) || + (latest_results->dims->data[1] != kCategoryCount)) { + error_reporter_->Report( + "The results for recognition should contain %d elements, but there are " + "%d in an %d-dimensional shape", + kCategoryCount, latest_results->dims->data[1], + latest_results->dims->size); + return kTfLiteError; + } + + if (latest_results->type != kTfLiteUInt8) { + error_reporter_->Report( + "The results for recognition should be uint8 elements, but are %d", + latest_results->type); + return kTfLiteError; + } + + if ((!previous_results_.empty()) && + (current_time_ms < previous_results_.front().time_)) { + error_reporter_->Report( + "Results must be fed in increasing time order, but received a " + "timestamp of %d that was earlier than the previous one of %d", + current_time_ms, previous_results_.front().time_); + return kTfLiteError; + } + + // Add the latest results to the head of the queue. + previous_results_.push_back({current_time_ms, latest_results->data.uint8}); + + // Prune any earlier results that are too old for the averaging window. + const int64_t time_limit = current_time_ms - average_window_duration_ms_; + while ((!previous_results_.empty()) && + previous_results_.front().time_ < time_limit) { + previous_results_.pop_front(); + } + + // If there are too few results, assume the result will be unreliable and + // bail. + const int64_t how_many_results = previous_results_.size(); + const int64_t earliest_time = previous_results_.front().time_; + const int64_t samples_duration = current_time_ms - earliest_time; + if ((how_many_results < minimum_count_) || + (samples_duration < (average_window_duration_ms_ / 4))) { + *found_command = previous_top_label_; + *score = 0; + *is_new_command = false; + return kTfLiteOk; + } + + // Calculate the average score across all the results in the window. + int32_t average_scores[kCategoryCount]; + for (int offset = 0; offset < previous_results_.size(); ++offset) { + PreviousResultsQueue::Result previous_result = + previous_results_.from_front(offset); + const uint8_t* scores = previous_result.scores_; + for (int i = 0; i < kCategoryCount; ++i) { + if (offset == 0) { + average_scores[i] = scores[i]; + } else { + average_scores[i] += scores[i]; + } + } + } + for (int i = 0; i < kCategoryCount; ++i) { + average_scores[i] /= how_many_results; + } + + // Find the current highest scoring category. + int current_top_index = 0; + int32_t current_top_score = 0; + for (int i = 0; i < kCategoryCount; ++i) { + if (average_scores[i] > current_top_score) { + current_top_score = average_scores[i]; + current_top_index = i; + } + } + const char* current_top_label = kCategoryLabels[current_top_index]; + + // If we've recently had another label trigger, assume one that occurs too + // soon afterwards is a bad result. + int64_t time_since_last_top; + if ((previous_top_label_ == kCategoryLabels[0]) || + (previous_top_label_time_ == std::numeric_limits::min())) { + time_since_last_top = std::numeric_limits::max(); + } else { + time_since_last_top = current_time_ms - previous_top_label_time_; + } + if ((current_top_score > detection_threshold_) && + (current_top_label != previous_top_label_) && + (time_since_last_top > suppression_ms_)) { + previous_top_label_ = current_top_label; + previous_top_label_time_ = current_time_ms; + *is_new_command = true; + } else { + *is_new_command = false; + } + *found_command = current_top_label; + *score = current_top_score; + + return kTfLiteOk; +} diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h new file mode 100644 index 00000000000..adefffe8500 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h @@ -0,0 +1,158 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ + +#include + +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" + +// Partial implementation of std::dequeue, just providing the functionality +// that's needed to keep a record of previous neural network results over a +// short time period, so they can be averaged together to produce a more +// accurate overall prediction. This doesn't use any dynamic memory allocation +// so it's a better fit for microcontroller applications, but this does mean +// there are hard limits on the number of results it can store. +class PreviousResultsQueue { + public: + PreviousResultsQueue(tflite::ErrorReporter* error_reporter) + : error_reporter_(error_reporter), front_index_(0), size_(0) {} + + // Data structure that holds an inference result, and the time when it + // was recorded. + struct Result { + Result() : time_(0), scores_() {} + Result(int32_t time, uint8_t* scores) : time_(time) { + for (int i = 0; i < kCategoryCount; ++i) { + scores_[i] = scores[i]; + } + } + int32_t time_; + uint8_t scores_[kCategoryCount]; + }; + + int size() { return size_; } + bool empty() { return size_ == 0; } + Result& front() { return results_[front_index_]; } + Result& back() { + int back_index = front_index_ + (size_ - 1); + if (back_index >= kMaxResults) { + back_index -= kMaxResults; + } + return results_[back_index]; + } + + void push_back(const Result& entry) { + if (size() >= kMaxResults) { + error_reporter_->Report( + "Couldn't push_back latest result, too many already!"); + return; + } + size_ += 1; + back() = entry; + } + + Result pop_front() { + if (size() <= 0) { + error_reporter_->Report("Couldn't pop_front result, none present!"); + return Result(); + } + Result result = front(); + front_index_ += 1; + if (front_index_ >= kMaxResults) { + front_index_ = 0; + } + size_ -= 1; + return result; + } + + // Most of the functions are duplicates of dequeue containers, but this + // is a helper that makes it easy to iterate through the contents of the + // queue. + Result& from_front(int offset) { + if ((offset < 0) || (offset >= size_)) { + error_reporter_->Report("Attempt to read beyond the end of the queue!"); + offset = size_ - 1; + } + int index = front_index_ + offset; + if (index >= kMaxResults) { + index -= kMaxResults; + } + return results_[index]; + } + + private: + tflite::ErrorReporter* error_reporter_; + static constexpr int kMaxResults = 50; + Result results_[kMaxResults]; + + int front_index_; + int size_; +}; + +// This class is designed to apply a very primitive decoding model on top of the +// instantaneous results from running an audio recognition model on a single +// window of samples. It applies smoothing over time so that noisy individual +// label scores are averaged, increasing the confidence that apparent matches +// are real. +// To use it, you should create a class object with the configuration you +// want, and then feed results from running a TensorFlow model into the +// processing method. The timestamp for each subsequent call should be +// increasing from the previous, since the class is designed to process a stream +// of data over time. +class RecognizeCommands { + public: + // labels should be a list of the strings associated with each one-hot score. + // The window duration controls the smoothing. Longer durations will give a + // higher confidence that the results are correct, but may miss some commands. + // The detection threshold has a similar effect, with high values increasing + // the precision at the cost of recall. The minimum count controls how many + // results need to be in the averaging window before it's seen as a reliable + // average. This prevents erroneous results when the averaging window is + // initially being populated for example. The suppression argument disables + // further recognitions for a set time after one has been triggered, which can + // help reduce spurious recognitions. + explicit RecognizeCommands(tflite::ErrorReporter* error_reporter, + int32_t average_window_duration_ms = 1000, + uint8_t detection_threshold = 51, + int32_t suppression_ms = 500, + int32_t minimum_count = 3); + + // Call this with the results of running a model on sample data. + TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results, + const int32_t current_time_ms, + const char** found_command, uint8_t* score, + bool* is_new_command); + + private: + // Configuration + tflite::ErrorReporter* error_reporter_; + int32_t average_window_duration_ms_; + uint8_t detection_threshold_; + int32_t suppression_ms_; + int32_t minimum_count_; + + // Working variables + PreviousResultsQueue previous_results_; + int previous_results_head_; + int previous_results_tail_; + const char* previous_top_label_; + int32_t previous_top_label_time_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_ diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc new file mode 100644 index 00000000000..f0cc73f10b3 --- /dev/null +++ b/tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc @@ -0,0 +1,207 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h" + +#include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + PreviousResultsQueue queue(error_reporter); + TF_LITE_MICRO_EXPECT_EQ(0, queue.size()); + + uint8_t scores_a[4] = {0, 0, 0, 1}; + queue.push_back({0, scores_a}); + TF_LITE_MICRO_EXPECT_EQ(1, queue.size()); + TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_); + TF_LITE_MICRO_EXPECT_EQ(0, queue.back().time_); + + uint8_t scores_b[4] = {0, 0, 1, 0}; + queue.push_back({1, scores_b}); + TF_LITE_MICRO_EXPECT_EQ(2, queue.size()); + TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_); + TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_); + + PreviousResultsQueue::Result pop_result = queue.pop_front(); + TF_LITE_MICRO_EXPECT_EQ(0, pop_result.time_); + TF_LITE_MICRO_EXPECT_EQ(1, queue.size()); + TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_); + TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_); + + uint8_t scores_c[4] = {0, 1, 0, 0}; + queue.push_back({2, scores_c}); + TF_LITE_MICRO_EXPECT_EQ(2, queue.size()); + TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_); + TF_LITE_MICRO_EXPECT_EQ(2, queue.back().time_); +} + +TF_LITE_MICRO_TEST(PreviousResultsQueuePushPop) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + PreviousResultsQueue queue(error_reporter); + TF_LITE_MICRO_EXPECT_EQ(0, queue.size()); + + for (int i = 0; i < 123; ++i) { + uint8_t scores[4] = {0, 0, 0, 1}; + queue.push_back({i, scores}); + TF_LITE_MICRO_EXPECT_EQ(1, queue.size()); + TF_LITE_MICRO_EXPECT_EQ(i, queue.front().time_); + TF_LITE_MICRO_EXPECT_EQ(i, queue.back().time_); + + PreviousResultsQueue::Result pop_result = queue.pop_front(); + TF_LITE_MICRO_EXPECT_EQ(i, pop_result.time_); + TF_LITE_MICRO_EXPECT_EQ(0, queue.size()); + } +} + +TF_LITE_MICRO_TEST(RecognizeCommandsTestBasic) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + RecognizeCommands recognize_commands(error_reporter); + + TfLiteTensor results = tflite::testing::CreateQuantizedTensor( + {255, 0, 0, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), + "input_tensor", 0.0f, 128.0f); + + const char* found_command; + uint8_t score; + bool is_new_command; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &results, 0, &found_command, &score, &is_new_command)); +} + +TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + RecognizeCommands recognize_commands(error_reporter, 1000, 51); + + TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor( + {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), + "input_tensor", 0.0f, 128.0f); + + bool has_found_new_command = false; + const char* new_command; + for (int i = 0; i < 10; ++i) { + const char* found_command; + uint8_t score; + bool is_new_command; + int32_t current_time_ms = 0 + (i * 100); + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &yes_results, current_time_ms, &found_command, &score, + &is_new_command)); + if (is_new_command) { + TF_LITE_MICRO_EXPECT(!has_found_new_command); + has_found_new_command = true; + new_command = found_command; + } + } + TF_LITE_MICRO_EXPECT(has_found_new_command); + TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command)); + + TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor( + {0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), + "input_tensor", 0.0f, 128.0f); + has_found_new_command = false; + new_command = ""; + uint8_t score; + for (int i = 0; i < 10; ++i) { + const char* found_command; + bool is_new_command; + int32_t current_time_ms = 1000 + (i * 100); + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &no_results, current_time_ms, &found_command, &score, + &is_new_command)); + if (is_new_command) { + TF_LITE_MICRO_EXPECT(!has_found_new_command); + has_found_new_command = true; + new_command = found_command; + } + } + TF_LITE_MICRO_EXPECT(has_found_new_command); + TF_LITE_MICRO_EXPECT_EQ(231, score); + TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command)); +} + +TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + RecognizeCommands recognize_commands(error_reporter, 1000, 51); + + TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor( + {0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 3}), + "input_tensor", 0.0f, 128.0f); + + const char* found_command; + uint8_t score; + bool is_new_command; + TF_LITE_MICRO_EXPECT_NE( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &bad_results, 0, &found_command, &score, &is_new_command)); +} + +TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputTimes) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + RecognizeCommands recognize_commands(error_reporter, 1000, 51); + + TfLiteTensor results = tflite::testing::CreateQuantizedTensor( + {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), + "input_tensor", 0.0f, 128.0f); + + const char* found_command; + uint8_t score; + bool is_new_command; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &results, 100, &found_command, &score, &is_new_command)); + TF_LITE_MICRO_EXPECT_NE( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &results, 0, &found_command, &score, &is_new_command)); +} + +TF_LITE_MICRO_TEST(RecognizeCommandsTestTooFewInputs) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + RecognizeCommands recognize_commands(error_reporter, 1000, 51); + + TfLiteTensor results = tflite::testing::CreateQuantizedTensor( + {0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}), + "input_tensor", 0.0f, 128.0f); + + const char* found_command; + uint8_t score; + bool is_new_command; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, recognize_commands.ProcessLatestResults( + &results, 100, &found_command, &score, &is_new_command)); + TF_LITE_MICRO_EXPECT_EQ(0, score); + TF_LITE_MICRO_EXPECT_EQ(false, is_new_command); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD index a54fd41760d..47ac85c6054 100644 --- a/tensorflow/lite/experimental/micro/kernels/BUILD +++ b/tensorflow/lite/experimental/micro/kernels/BUILD @@ -48,22 +48,6 @@ cc_library( ], ) -cc_library( - name = "test_utils", - srcs = [ - ], - hdrs = [ - "test_utils.h", - ], - copts = tflite_copts(), - deps = [ - "//tensorflow/lite/c:c_api_internal", - "//tensorflow/lite/core/api", - "//tensorflow/lite/experimental/micro:micro_framework", - "//tensorflow/lite/experimental/micro/testing:micro_test", - ], -) - tflite_micro_cc_test( name = "depthwise_conv_test", srcs = [ @@ -71,7 +55,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":test_utils", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -85,7 +68,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":test_utils", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", @@ -99,7 +81,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":test_utils", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", "//tensorflow/lite/experimental/micro/testing:micro_test", diff --git a/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc index f70437a4b94..05ba8798c0d 100644 --- a/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" #include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc index 300f8aaf78a..c2e1446848d 100644 --- a/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" #include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/lite/experimental/micro/kernels/softmax_test.cc index 7253b3be8ce..8933b6c0ed0 100644 --- a/tensorflow/lite/experimental/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/experimental/micro/kernels/softmax_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h" -#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" #include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/lite/experimental/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/experimental/micro/testing/BUILD b/tensorflow/lite/experimental/micro/testing/BUILD index 5a31a709ca3..1623df5b865 100644 --- a/tensorflow/lite/experimental/micro/testing/BUILD +++ b/tensorflow/lite/experimental/micro/testing/BUILD @@ -10,8 +10,10 @@ cc_library( name = "micro_test", hdrs = [ "micro_test.h", + "test_utils.h", ], deps = [ + "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/experimental/micro:micro_framework", ], ) diff --git a/tensorflow/lite/experimental/micro/kernels/test_utils.h b/tensorflow/lite/experimental/micro/testing/test_utils.h similarity index 91% rename from tensorflow/lite/experimental/micro/kernels/test_utils.h rename to tensorflow/lite/experimental/micro/testing/test_utils.h index 95f2d8a9d21..e37eaf46e08 100644 --- a/tensorflow/lite/experimental/micro/kernels/test_utils.h +++ b/tensorflow/lite/experimental/micro/testing/test_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_ #include #include @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" -#include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/lite/experimental/micro/micro_error_reporter.h" #include "tensorflow/lite/experimental/micro/testing/micro_test.h" namespace tflite { @@ -164,7 +163,20 @@ inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list data, return CreateQuantized32Tensor(data.begin(), dims, name, min, max); } +// Do a simple string comparison for testing purposes, without requiring the +// standard C library. +inline int TestStrcmp(const char* a, const char* b) { + if ((a == nullptr) || (b == nullptr)) { + return -1; + } + while ((*a != 0) && (*a == *b)) { + a++; + b++; + } + return *(const unsigned char*)a - *(const unsigned char*)b; +} + } // namespace testing } // namespace tflite -#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_