Return TfLiteStatus from Interpreter::SetNumThreads()
PiperOrigin-RevId: 318472145 Change-Id: Icb3fd7575c27638e61911f43ca0ebe952236da2c
This commit is contained in:
parent
c9a4b8edd4
commit
e91990219b
@ -302,12 +302,12 @@ TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
|
||||
|
||||
void Interpreter::UseNNAPI(bool enable) { primary_subgraph().UseNNAPI(enable); }
|
||||
|
||||
void Interpreter::SetNumThreads(int num_threads) {
|
||||
TfLiteStatus Interpreter::SetNumThreads(int num_threads) {
|
||||
if (num_threads < -1) {
|
||||
context_->ReportError(context_,
|
||||
"num_threads should be >=0 or just -1 to let TFLite "
|
||||
"runtime set the value.");
|
||||
return;
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
for (auto& subgraph : subgraphs_) {
|
||||
@ -320,6 +320,7 @@ void Interpreter::SetNumThreads(int num_threads) {
|
||||
c->Refresh(context_);
|
||||
}
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
|
||||
|
||||
@ -371,7 +371,7 @@ class Interpreter {
|
||||
/// NOTE: num_threads should be >= -1.
|
||||
/// User may pass -1 to let the TFLite interpreter set the no of threads
|
||||
/// available to itself.
|
||||
void SetNumThreads(int num_threads);
|
||||
TfLiteStatus SetNumThreads(int num_threads);
|
||||
|
||||
/// Allow float16 precision for FP32 calculation when possible.
|
||||
/// default: not allow.
|
||||
|
||||
@ -1131,26 +1131,26 @@ TEST_F(InterpreterTest, GetSetResetExternalContexts) {
|
||||
};
|
||||
|
||||
EXPECT_EQ(TestExternalContext::Get(context), nullptr);
|
||||
interpreter_.SetNumThreads(4);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
|
||||
|
||||
TestExternalContext::Set(context, &external_context);
|
||||
EXPECT_EQ(TestExternalContext::Get(context), &external_context);
|
||||
interpreter_.SetNumThreads(4);
|
||||
interpreter_.SetNumThreads(5);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(5), kTfLiteOk);
|
||||
EXPECT_EQ(external_context.num_refreshes, 2);
|
||||
|
||||
// Reset refresh count to 0
|
||||
external_context.num_refreshes = 0;
|
||||
// Below should not call external context refresh
|
||||
interpreter_.SetNumThreads(-2);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(-2), kTfLiteError);
|
||||
EXPECT_EQ(external_context.num_refreshes, 0);
|
||||
|
||||
interpreter_.SetNumThreads(-1);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(-1), kTfLiteOk);
|
||||
EXPECT_EQ(external_context.num_refreshes, 1);
|
||||
|
||||
TestExternalContext::Set(context, nullptr);
|
||||
EXPECT_EQ(TestExternalContext::Get(context), nullptr);
|
||||
interpreter_.SetNumThreads(4);
|
||||
ASSERT_EQ(interpreter_.SetNumThreads(4), kTfLiteOk);
|
||||
}
|
||||
|
||||
struct TestCpuBackendContext : public TfLiteInternalBackendContext {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user