Wire TFLite thread count to the eager context
PiperOrigin-RevId: 226214825
This commit is contained in:
parent
81e98fcb01
commit
2d1a3052e1
@ -30,6 +30,21 @@ namespace flex {
|
|||||||
namespace delegate {
|
namespace delegate {
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||||
|
// If the TensorFlow Lite thread count is explicitly configured, use it,
|
||||||
|
// otherwise rely on the default TensorFlow threading behavior.
|
||||||
|
tensorflow::SessionOptions session_options;
|
||||||
|
if (context->recommended_num_threads > 0) {
|
||||||
|
session_options.config.set_intra_op_parallelism_threads(
|
||||||
|
context->recommended_num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!reinterpret_cast<DelegateData*>(delegate->data_)
|
||||||
|
->Prepare(session_options)
|
||||||
|
.ok()) {
|
||||||
|
context->ReportError(context, "Failed to initialize TensorFlow context.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
// Get the nodes in the current execution plan. Interpreter owns this array.
|
// Get the nodes in the current execution plan. Interpreter owns this array.
|
||||||
TfLiteIntArray* plan;
|
TfLiteIntArray* plan;
|
||||||
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
|
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
|
||||||
@ -118,20 +133,11 @@ AcquireFlexDelegate() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
|
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
|
||||||
std::unique_ptr<flex::DelegateData> delegate_data;
|
return std::unique_ptr<FlexDelegate>(new FlexDelegate());
|
||||||
if (!flex::DelegateData::Create(&delegate_data).ok()) {
|
|
||||||
fprintf(stderr, "Unable to initialize TensorFlow context.\n");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::unique_ptr<FlexDelegate>(
|
|
||||||
new FlexDelegate(std::move(delegate_data)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
|
FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
|
||||||
: TfLiteDelegate(TfLiteDelegateCreate()),
|
data_ = &delegate_data_;
|
||||||
delegate_data_(std::move(delegate_data)) {
|
|
||||||
data_ = delegate_data_.get();
|
|
||||||
Prepare = &flex::delegate::Prepare;
|
Prepare = &flex::delegate::Prepare;
|
||||||
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
|
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
|
||||||
flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
||||||
|
@ -49,9 +49,9 @@ class FlexDelegate : public TfLiteDelegate {
|
|||||||
~FlexDelegate();
|
~FlexDelegate();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
|
FlexDelegate();
|
||||||
|
|
||||||
std::unique_ptr<flex::DelegateData> delegate_data_;
|
flex::DelegateData delegate_data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -20,29 +20,32 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace flex {
|
namespace flex {
|
||||||
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
|
DelegateData::DelegateData() {}
|
||||||
|
|
||||||
|
DelegateData::~DelegateData() {}
|
||||||
|
|
||||||
|
tensorflow::Status DelegateData::Prepare(
|
||||||
|
const tensorflow::SessionOptions& session_options) {
|
||||||
|
if (eager_context_) {
|
||||||
|
return tensorflow::Status();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
|
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
|
||||||
tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
|
session_options, "/job:localhost/replica:0/task:0", &devices));
|
||||||
&devices));
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
|
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
|
||||||
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
|
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
|
||||||
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
// Note that Rendezvous is ref-counted so it will be automatically deleted.
|
||||||
tensorflow::Rendezvous* rendezvous =
|
tensorflow::Rendezvous* rendezvous =
|
||||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||||
data->reset(new DelegateData(new tensorflow::EagerContext(
|
eager_context_.reset(new tensorflow::EagerContext(
|
||||||
tensorflow::SessionOptions(),
|
session_options,
|
||||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||||
/*async=*/false, std::move(device_mgr), rendezvous)));
|
/*async=*/false, std::move(device_mgr), rendezvous));
|
||||||
return tensorflow::Status();
|
return tensorflow::Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
|
|
||||||
: eager_context_(eager_context) {}
|
|
||||||
|
|
||||||
DelegateData::~DelegateData() {}
|
|
||||||
|
|
||||||
} // namespace flex
|
} // namespace flex
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -15,21 +15,30 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
|
#define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/flex/buffer_map.h"
|
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
#include "tensorflow/lite/delegates/flex/buffer_map.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace flex {
|
namespace flex {
|
||||||
|
|
||||||
// Data kept by the Flex delegate for the lifetime of an Interpreter.
|
// Data kept by the Flex delegate for the lifetime of an Interpreter.
|
||||||
|
//
|
||||||
|
// Note: This class is *not* thread-safe; any dependent delegates should not be
|
||||||
|
// used concurrently.
|
||||||
class DelegateData {
|
class DelegateData {
|
||||||
public:
|
public:
|
||||||
// Create a new DelegateData, initialized with a newly-created EagerContext.
|
DelegateData();
|
||||||
static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
|
|
||||||
|
|
||||||
~DelegateData();
|
~DelegateData();
|
||||||
|
|
||||||
|
// Prepare the necessary EagerContext and data for execution.
|
||||||
|
// This must be called at least once before execution. After preparation
|
||||||
|
// succeeds, redundant calls will be ignored (even if the session_options
|
||||||
|
// differ).
|
||||||
|
tensorflow::Status Prepare(const tensorflow::SessionOptions& session_options);
|
||||||
|
|
||||||
// The EagerContext that is required for execution of Flex Ops.
|
// The EagerContext that is required for execution of Flex Ops.
|
||||||
|
// Note: The context is lazily created after the first call to |Prepare()|.
|
||||||
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
|
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
|
||||||
|
|
||||||
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
|
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
|
||||||
@ -38,8 +47,7 @@ class DelegateData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit DelegateData(tensorflow::EagerContext* eager_context);
|
// Will be null until Prepare() is called and completes successfully.
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::EagerContext> eager_context_;
|
std::unique_ptr<tensorflow::EagerContext> eager_context_;
|
||||||
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
|
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
|
||||||
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
|
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
|
||||||
|
@ -24,18 +24,20 @@ namespace flex {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(DelegateDataTest, Basic) {
|
TEST(DelegateDataTest, Basic) {
|
||||||
std::unique_ptr<DelegateData> data;
|
DelegateData data;
|
||||||
// We only check for success because it is hard to make initialization fail.
|
// We only check for success because it is hard to make initialization fail.
|
||||||
// It only happens if we manage to not link the CPU device factory into the
|
// It only happens if we manage to not link the CPU device factory into the
|
||||||
// binary.
|
// binary.
|
||||||
EXPECT_TRUE(DelegateData::Create(&data).ok());
|
tensorflow::SessionOptions session_options;
|
||||||
|
session_options.config.set_intra_op_parallelism_threads(2);
|
||||||
|
EXPECT_TRUE(data.Prepare(session_options).ok());
|
||||||
|
|
||||||
TfLiteContext dummy_context1 = {};
|
TfLiteContext dummy_context1 = {};
|
||||||
TfLiteContext dummy_context2 = {};
|
TfLiteContext dummy_context2 = {};
|
||||||
EXPECT_NE(data->GetEagerContext(), nullptr);
|
EXPECT_NE(data.GetEagerContext(), nullptr);
|
||||||
EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr);
|
EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
|
||||||
EXPECT_NE(data->GetBufferMap(&dummy_context1),
|
EXPECT_NE(data.GetBufferMap(&dummy_context1),
|
||||||
data->GetBufferMap(&dummy_context2));
|
data.GetBufferMap(&dummy_context2));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -252,6 +252,56 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DelegateTest, SingleThreaded) {
|
||||||
|
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
|
||||||
|
AddTfOp(testing::kUnpack, {0}, {1, 2});
|
||||||
|
AddTfOp(testing::kUnpack, {3}, {4, 5});
|
||||||
|
AddTfOp(testing::kAdd, {1, 4}, {6});
|
||||||
|
AddTfOp(testing::kAdd, {2, 5}, {7});
|
||||||
|
AddTfOp(testing::kMul, {6, 7}, {8});
|
||||||
|
|
||||||
|
// Explicitly disable multi-threading before installing the delegate.
|
||||||
|
interpreter_->SetNumThreads(1);
|
||||||
|
ConfigureDelegate();
|
||||||
|
|
||||||
|
SetShape(0, {2, 2, 1});
|
||||||
|
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||||
|
SetShape(3, {2, 2, 1});
|
||||||
|
SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||||
|
|
||||||
|
// Invocation should behave as expected.
|
||||||
|
ASSERT_TRUE(Invoke());
|
||||||
|
|
||||||
|
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
|
||||||
|
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
|
||||||
|
ASSERT_EQ(GetType(8), kTfLiteFloat32);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DelegateTest, MultiThreaded) {
|
||||||
|
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
|
||||||
|
AddTfOp(testing::kUnpack, {0}, {1, 2});
|
||||||
|
AddTfOp(testing::kUnpack, {3}, {4, 5});
|
||||||
|
AddTfOp(testing::kAdd, {1, 4}, {6});
|
||||||
|
AddTfOp(testing::kAdd, {2, 5}, {7});
|
||||||
|
AddTfOp(testing::kMul, {6, 7}, {8});
|
||||||
|
|
||||||
|
// Explicitly enable multi-threading before installing the delegate.
|
||||||
|
interpreter_->SetNumThreads(4);
|
||||||
|
ConfigureDelegate();
|
||||||
|
|
||||||
|
SetShape(0, {2, 2, 1});
|
||||||
|
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||||
|
SetShape(3, {2, 2, 1});
|
||||||
|
SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
|
||||||
|
|
||||||
|
// Invocation should behave as expected.
|
||||||
|
ASSERT_TRUE(Invoke());
|
||||||
|
|
||||||
|
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
|
||||||
|
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
|
||||||
|
ASSERT_EQ(GetType(8), kTfLiteFloat32);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace flex
|
} // namespace flex
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -39,20 +39,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
|
|||||||
class KernelTest : public testing::FlexModelTest {
|
class KernelTest : public testing::FlexModelTest {
|
||||||
public:
|
public:
|
||||||
KernelTest() {
|
KernelTest() {
|
||||||
CHECK(DelegateData::Create(&delegate_data_).ok());
|
CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
|
||||||
interpreter_.reset(new Interpreter(&error_reporter_));
|
interpreter_.reset(new Interpreter(&error_reporter_));
|
||||||
}
|
}
|
||||||
|
|
||||||
~KernelTest() override {
|
|
||||||
// The data needs to be released before the interpreter because the
|
|
||||||
// interpreter references the data.
|
|
||||||
delegate_data_.reset();
|
|
||||||
interpreter_.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ConfigureDelegate(T prepare_function) {
|
void ConfigureDelegate(T prepare_function) {
|
||||||
delegate_.data_ = delegate_data_.get();
|
delegate_.data_ = &delegate_data_;
|
||||||
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
|
||||||
delegate_.FreeBufferHandle = nullptr;
|
delegate_.FreeBufferHandle = nullptr;
|
||||||
delegate_.Prepare = prepare_function;
|
delegate_.Prepare = prepare_function;
|
||||||
@ -71,7 +64,7 @@ class KernelTest : public testing::FlexModelTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<DelegateData> delegate_data_;
|
DelegateData delegate_data_;
|
||||||
TfLiteDelegate delegate_;
|
TfLiteDelegate delegate_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -316,18 +316,13 @@ void BenchmarkTfLiteModel::Init() {
|
|||||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
|
const int32_t num_threads = params_.Get<int32_t>("num_threads");
|
||||||
|
tflite::InterpreterBuilder(*model, resolver)(&interpreter, num_threads);
|
||||||
if (!interpreter) {
|
if (!interpreter) {
|
||||||
TFLITE_LOG(FATAL) << "Failed to construct interpreter";
|
TFLITE_LOG(FATAL) << "Failed to construct interpreter";
|
||||||
}
|
}
|
||||||
profiling_listener_.SetInterpreter(interpreter.get());
|
profiling_listener_.SetInterpreter(interpreter.get());
|
||||||
|
|
||||||
const int32_t num_threads = params_.Get<int32_t>("num_threads");
|
|
||||||
|
|
||||||
if (num_threads != -1) {
|
|
||||||
interpreter->SetNumThreads(num_threads);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool use_nnapi = params_.Get<bool>("use_nnapi");
|
bool use_nnapi = params_.Get<bool>("use_nnapi");
|
||||||
|
|
||||||
interpreter->UseNNAPI(use_nnapi);
|
interpreter->UseNNAPI(use_nnapi);
|
||||||
|
Loading…
Reference in New Issue
Block a user