Fail the model evaluation if there's no successful TFLite interpreter invocations in all attempts.
PiperOrigin-RevId: 258908895
This commit is contained in:
parent
69359f86c9
commit
303384cce6
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex> // NOLINT(build/c++11)
|
#include <mutex> // NOLINT(build/c++11)
|
||||||
@ -178,17 +179,17 @@ int Main(int argc, char* argv[]) {
|
|||||||
std::unique_ptr<ImagenetModelEvaluator> evaluator;
|
std::unique_ptr<ImagenetModelEvaluator> evaluator;
|
||||||
if (output_file_path.empty()) {
|
if (output_file_path.empty()) {
|
||||||
LOG(ERROR) << "Invalid output file path.";
|
LOG(ERROR) << "Invalid output file path.";
|
||||||
return 0;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (num_threads <= 0) {
|
if (num_threads <= 0) {
|
||||||
LOG(ERROR) << "Invalid number of threads.";
|
LOG(ERROR) << "Invalid number of threads.";
|
||||||
return 0;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
|
if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
|
||||||
kTfLiteOk)
|
kTfLiteOk)
|
||||||
return 0;
|
return EXIT_FAILURE;
|
||||||
|
|
||||||
std::ofstream output_stream(output_file_path, std::ios::out);
|
std::ofstream output_stream(output_file_path, std::ios::out);
|
||||||
if (!output_stream) {
|
if (!output_stream) {
|
||||||
@ -210,7 +211,10 @@ int Main(int argc, char* argv[]) {
|
|||||||
absl::make_unique<CSVWriter>(columns, &output_stream));
|
absl::make_unique<CSVWriter>(columns, &output_stream));
|
||||||
evaluator->AddObserver(&results_writer);
|
evaluator->AddObserver(&results_writer);
|
||||||
LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
|
LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
|
||||||
evaluator->EvaluateModel();
|
if (evaluator->EvaluateModel() != kTfLiteOk) {
|
||||||
|
LOG(ERROR) << "Failed to evaluate the model!";
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
if (!proto_output_file_path.empty()) {
|
if (!proto_output_file_path.empty()) {
|
||||||
std::ofstream proto_out_file(proto_output_file_path,
|
std::ofstream proto_out_file(proto_output_file_path,
|
||||||
@ -220,7 +224,7 @@ int Main(int argc, char* argv[]) {
|
|||||||
proto_out_file.close();
|
proto_out_file.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace metrics
|
} // namespace metrics
|
||||||
|
@ -263,7 +263,7 @@ TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
|||||||
&all_okay]() {
|
&all_okay]() {
|
||||||
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
|
if (EvaluateModelForShard(shard_id, image_label, model_labels, params_,
|
||||||
&observer, params_.num_ranks) != kTfLiteOk) {
|
&observer, params_.num_ranks) != kTfLiteOk) {
|
||||||
all_okay = all_okay && false;
|
all_okay = false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
thread_pool.push_back(std::thread(func));
|
thread_pool.push_back(std::thread(func));
|
||||||
@ -274,7 +274,7 @@ TfLiteStatus ImagenetModelEvaluator::EvaluateModel() const {
|
|||||||
thread.join();
|
thread.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return all_okay ? kTfLiteOk : kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace metrics
|
} // namespace metrics
|
||||||
|
@ -140,7 +140,10 @@ TfLiteStatus TfliteInferenceStage::Run() {
|
|||||||
auto& params = config_.specification().tflite_inference_params();
|
auto& params = config_.specification().tflite_inference_params();
|
||||||
for (int i = 0; i < params.invocations_per_run(); ++i) {
|
for (int i = 0; i < params.invocations_per_run(); ++i) {
|
||||||
int64_t start_us = profiling::time::NowMicros();
|
int64_t start_us = profiling::time::NowMicros();
|
||||||
interpreter_->Invoke();
|
if (interpreter_->Invoke() != kTfLiteOk) {
|
||||||
|
LOG(ERROR) << "TFLite interpreter failed to invoke at run " << i;
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
|
latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user