Return TfLiteStatus from Interpreter::SetNumThreads()

PiperOrigin-RevId: 318472145
Change-Id: Icb3fd7575c27638e61911f43ca0ebe952236da2c
This commit is contained in:
Dmitry Kovalev 2020-06-26 08:03:50 -07:00 committed by TensorFlower Gardener
parent c9a4b8edd4
commit e91990219b
3 changed files with 10 additions and 9 deletions

View File

@ -302,12 +302,12 @@ TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
void Interpreter::UseNNAPI(bool enable) { primary_subgraph().UseNNAPI(enable); }
void Interpreter::SetNumThreads(int num_threads) {
TfLiteStatus Interpreter::SetNumThreads(int num_threads) {
if (num_threads < -1) {
context_->ReportError(context_,
"num_threads should be >=0 or just -1 to let TFLite "
"runtime set the value.");
return;
return kTfLiteError;
}
for (auto& subgraph : subgraphs_) {
@ -320,6 +320,7 @@ void Interpreter::SetNumThreads(int num_threads) {
c->Refresh(context_);
}
}
return kTfLiteOk;
}
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {

View File

@ -371,7 +371,7 @@ class Interpreter {
/// NOTE: num_threads should be >= -1.
/// User may pass -1 to let the TFLite interpreter set the no of threads
/// available to itself.
void SetNumThreads(int num_threads);
TfLiteStatus SetNumThreads(int num_threads);
/// Allow float16 precision for FP32 calculation when possible.
/// default: not allow.

View File

@ -1131,26 +1131,26 @@ TEST_F(InterpreterTest, GetSetResetExternalContexts) {
};
EXPECT_EQ(TestExternalContext::Get(context), nullptr);
interpreter_.SetNumThreads(4);
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
TestExternalContext::Set(context, &external_context);
EXPECT_EQ(TestExternalContext::Get(context), &external_context);
interpreter_.SetNumThreads(4);
interpreter_.SetNumThreads(5);
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
ASSERT_EQ(interpreter_.SetNumThreads(5), kTfLiteOk);
EXPECT_EQ(external_context.num_refreshes, 2);
// Reset refresh count to 0
external_context.num_refreshes = 0;
// Below should not call external context refresh
interpreter_.SetNumThreads(-2);
ASSERT_EQ(interpreter_.SetNumThreads(-2), kTfLiteError);
EXPECT_EQ(external_context.num_refreshes, 0);
interpreter_.SetNumThreads(-1);
ASSERT_EQ(interpreter_.SetNumThreads(-1), kTfLiteOk);
EXPECT_EQ(external_context.num_refreshes, 1);
TestExternalContext::Set(context, nullptr);
EXPECT_EQ(TestExternalContext::Get(context), nullptr);
interpreter_.SetNumThreads(4);
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
}
struct TestCpuBackendContext : public TfLiteInternalBackendContext {