Wire TFLite thread count to the eager context
PiperOrigin-RevId: 226214825
commit 2d1a3052e1
parent 81e98fcb01
@@ -30,6 +30,21 @@ namespace flex {
 namespace delegate {
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
+  // If the TensorFlow Lite thread count is explicitly configured, use it,
+  // otherwise rely on the default TensorFlow threading behavior.
+  tensorflow::SessionOptions session_options;
+  if (context->recommended_num_threads > 0) {
+    session_options.config.set_intra_op_parallelism_threads(
+        context->recommended_num_threads);
+  }
+
+  if (!reinterpret_cast<DelegateData*>(delegate->data_)
+           ->Prepare(session_options)
+           .ok()) {
+    context->ReportError(context, "Failed to initialize TensorFlow context.");
+    return kTfLiteError;
+  }
+
   // Get the nodes in the current execution plan. Interpreter owns this array.
   TfLiteIntArray* plan;
   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
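In practice, the wiring above means a thread count set on the TFLite interpreter now propagates into the EagerContext's intra-op thread pool. A minimal usage sketch (hypothetical helper, not part of this commit; assumes the flex delegate is linked in, and that SetNumThreads() runs before the delegate is installed, since Prepare() reads recommended_num_threads only when the delegate is applied):

    #include <memory>

    #include "tensorflow/lite/delegates/flex/delegate.h"
    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"

    void RunWithFourThreads(const tflite::FlatBufferModel& model) {
      // The delegate must outlive the interpreter that uses it.
      std::unique_ptr<tflite::FlexDelegate> delegate = tflite::FlexDelegate::Create();

      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(model, resolver)(&interpreter);

      // Set the thread count first; flex::delegate::Prepare picks it up from
      // context->recommended_num_threads when the delegate is installed.
      interpreter->SetNumThreads(4);
      interpreter->ModifyGraphWithDelegate(delegate.get());

      interpreter->AllocateTensors();
      interpreter->Invoke();
    }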
@@ -118,20 +133,11 @@ AcquireFlexDelegate() {
 }
 
 std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
-  std::unique_ptr<flex::DelegateData> delegate_data;
-  if (!flex::DelegateData::Create(&delegate_data).ok()) {
-    fprintf(stderr, "Unable to initialize TensorFlow context.\n");
-    return nullptr;
-  }
-
-  return std::unique_ptr<FlexDelegate>(
-      new FlexDelegate(std::move(delegate_data)));
+  return std::unique_ptr<FlexDelegate>(new FlexDelegate());
 }
 
-FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
-    : TfLiteDelegate(TfLiteDelegateCreate()),
-      delegate_data_(std::move(delegate_data)) {
-  data_ = delegate_data_.get();
+FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
+  data_ = &delegate_data_;
   Prepare = &flex::delegate::Prepare;
   CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
   flags = kTfLiteDelegateFlagsAllowDynamicTensors;
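Note that with DelegateData now stored by value (see the header change below), FlexDelegate::Create() has no remaining failure path: the EagerContext is no longer built at delegate construction time but lazily in flex::delegate::Prepare() above, the first point at which the interpreter's recommended thread count is known.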
@@ -49,9 +49,9 @@ class FlexDelegate : public TfLiteDelegate {
   ~FlexDelegate();
 
  private:
-  explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
+  FlexDelegate();
 
-  std::unique_ptr<flex::DelegateData> delegate_data_;
+  flex::DelegateData delegate_data_;
 };
 
 }  // namespace tflite
@@ -20,29 +20,32 @@ limitations under the License.
 
 namespace tflite {
 namespace flex {
-tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
+DelegateData::DelegateData() {}
+
+DelegateData::~DelegateData() {}
+
+tensorflow::Status DelegateData::Prepare(
+    const tensorflow::SessionOptions& session_options) {
+  if (eager_context_) {
+    return tensorflow::Status();
+  }
+
   std::vector<std::unique_ptr<tensorflow::Device>> devices;
 
   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
-      tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
-      &devices));
+      session_options, "/job:localhost/replica:0/task:0", &devices));
 
   std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
       absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
   // Note that Rendezvous is ref-counted so it will be automatically deleted.
   tensorflow::Rendezvous* rendezvous =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
-  data->reset(new DelegateData(new tensorflow::EagerContext(
-      tensorflow::SessionOptions(),
+  eager_context_.reset(new tensorflow::EagerContext(
+      session_options,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /*async=*/false, std::move(device_mgr), rendezvous)));
+      /*async=*/false, std::move(device_mgr), rendezvous));
   return tensorflow::Status();
 }
 
-DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
-    : eager_context_(eager_context) {}
-
-DelegateData::~DelegateData() {}
-
 }  // namespace flex
 }  // namespace tflite
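Prepare() is idempotent: once eager_context_ is set, later calls return immediately, so only the first caller's SessionOptions take effect. A short sketch of that contract (hypothetical snippet, not part of this commit; TF_CHECK_OK comes from tensorflow/core/lib/core/status.h):

    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/public/session_options.h"
    #include "tensorflow/lite/delegates/flex/delegate_data.h"

    void PrepareTwiceDemo() {
      tflite::flex::DelegateData data;

      tensorflow::SessionOptions two_threads;
      two_threads.config.set_intra_op_parallelism_threads(2);
      TF_CHECK_OK(data.Prepare(two_threads));  // First call creates the EagerContext.

      tensorflow::SessionOptions eight_threads;
      eight_threads.config.set_intra_op_parallelism_threads(8);
      TF_CHECK_OK(data.Prepare(eight_threads));  // Ignored: a context already exists.
    }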
@@ -15,21 +15,30 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
 #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
 
-#include "tensorflow/lite/delegates/flex/buffer_map.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/lite/delegates/flex/buffer_map.h"
 
 namespace tflite {
 namespace flex {
 
 // Data kept by the Flex delegate for the lifetime of an Interpreter.
+//
+// Note: This class is *not* thread-safe; any dependent delegates should not be
+// used concurrently.
 class DelegateData {
  public:
-  // Create a new DelegateData, initialized with a newly-created EagerContext.
-  static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
-
+  DelegateData();
   ~DelegateData();
 
+  // Prepare the necessary EagerContext and data for execution.
+  // This must be called at least once before execution. After preparation
+  // succeeds, redundant calls will be ignored (even if the session_options
+  // differ).
+  tensorflow::Status Prepare(const tensorflow::SessionOptions& session_options);
+
   // The EagerContext that is required for execution of Flex Ops.
+  // Note: The context is lazily created after the first call to |Prepare()|.
   tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
 
   // Map from TF Lite tensor index to TensorFlow tensor for a given context.
@@ -38,8 +47,7 @@ class DelegateData {
   }
 
  private:
-  explicit DelegateData(tensorflow::EagerContext* eager_context);
-
+  // Will be null until Prepare() is called and completes successfully.
   std::unique_ptr<tensorflow::EagerContext> eager_context_;
   // TODO(b/112439500): Clean up stale BufferMap instances after adding the
   // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
@@ -24,18 +24,20 @@ namespace flex {
 namespace {
 
 TEST(DelegateDataTest, Basic) {
-  std::unique_ptr<DelegateData> data;
+  DelegateData data;
   // We only check for success because it is hard to make initialization fail.
   // It only happens if we manage to not link the CPU device factory into the
   // binary.
-  EXPECT_TRUE(DelegateData::Create(&data).ok());
+  tensorflow::SessionOptions session_options;
+  session_options.config.set_intra_op_parallelism_threads(2);
+  EXPECT_TRUE(data.Prepare(session_options).ok());
 
   TfLiteContext dummy_context1 = {};
   TfLiteContext dummy_context2 = {};
-  EXPECT_NE(data->GetEagerContext(), nullptr);
-  EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr);
-  EXPECT_NE(data->GetBufferMap(&dummy_context1),
-            data->GetBufferMap(&dummy_context2));
+  EXPECT_NE(data.GetEagerContext(), nullptr);
+  EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
+  EXPECT_NE(data.GetBufferMap(&dummy_context1),
+            data.GetBufferMap(&dummy_context2));
 }
 
 }  // namespace
@@ -252,6 +252,56 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
   }
 }
 
+TEST_F(DelegateTest, SingleThreaded) {
+  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+  AddTfOp(testing::kUnpack, {0}, {1, 2});
+  AddTfOp(testing::kUnpack, {3}, {4, 5});
+  AddTfOp(testing::kAdd, {1, 4}, {6});
+  AddTfOp(testing::kAdd, {2, 5}, {7});
+  AddTfOp(testing::kMul, {6, 7}, {8});
+
+  // Explicitly disable multi-threading before installing the delegate.
+  interpreter_->SetNumThreads(1);
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2, 1});
+  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+  SetShape(3, {2, 2, 1});
+  SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+  // Invocation should behave as expected.
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, MultiThreaded) {
+  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+  AddTfOp(testing::kUnpack, {0}, {1, 2});
+  AddTfOp(testing::kUnpack, {3}, {4, 5});
+  AddTfOp(testing::kAdd, {1, 4}, {6});
+  AddTfOp(testing::kAdd, {2, 5}, {7});
+  AddTfOp(testing::kMul, {6, 7}, {8});
+
+  // Explicitly enable multi-threading before installing the delegate.
+  interpreter_->SetNumThreads(4);
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2, 1});
+  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+  SetShape(3, {2, 2, 1});
+  SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+  // Invocation should behave as expected.
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
 }  // namespace
 }  // namespace flex
 }  // namespace tflite
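For reference, the expected outputs can be checked by hand (assuming kUnpack splits along the first axis): each {2, 2, 1} input unpacks into {1.1, 2.2} and {3.3, 4.4}; the two adds yield {2.2, 4.4} and {6.6, 8.8}; their elementwise product is {2.2 * 6.6, 4.4 * 8.8} = {14.52, 38.72} with shape {2, 1}. The single- and multi-threaded runs must agree on these values.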
@@ -39,20 +39,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
 class KernelTest : public testing::FlexModelTest {
  public:
   KernelTest() {
-    CHECK(DelegateData::Create(&delegate_data_).ok());
+    CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
     interpreter_.reset(new Interpreter(&error_reporter_));
   }
 
-  ~KernelTest() override {
-    // The data needs to be released before the interpreter because the
-    // interpreter references the data.
-    delegate_data_.reset();
-    interpreter_.reset();
-  }
-
   template <typename T>
   void ConfigureDelegate(T prepare_function) {
-    delegate_.data_ = delegate_data_.get();
+    delegate_.data_ = &delegate_data_;
     delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
     delegate_.FreeBufferHandle = nullptr;
     delegate_.Prepare = prepare_function;

@@ -71,7 +64,7 @@ class KernelTest : public testing::FlexModelTest {
   }
 
  private:
-  std::unique_ptr<DelegateData> delegate_data_;
+  DelegateData delegate_data_;
   TfLiteDelegate delegate_;
 };
@@ -316,18 +316,13 @@ void BenchmarkTfLiteModel::Init() {
   tflite::ops::builtin::BuiltinOpResolver resolver;
 #endif
 
-  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+  const int32_t num_threads = params_.Get<int32_t>("num_threads");
+  tflite::InterpreterBuilder(*model, resolver)(&interpreter, num_threads);
   if (!interpreter) {
     TFLITE_LOG(FATAL) << "Failed to construct interpreter";
   }
   profiling_listener_.SetInterpreter(interpreter.get());
 
-  const int32_t num_threads = params_.Get<int32_t>("num_threads");
-
-  if (num_threads != -1) {
-    interpreter->SetNumThreads(num_threads);
-  }
-
   bool use_nnapi = params_.Get<bool>("use_nnapi");
 
   interpreter->UseNNAPI(use_nnapi);
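Passing the thread count to InterpreterBuilder, rather than calling SetNumThreads() after construction, ensures the value is already in place for anything consulted during interpreter setup. A minimal sketch of the two-argument builder call (hypothetical helper name; -1 preserves the default threading behavior, matching the flag's default):

    #include <cstdint>
    #include <memory>

    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"

    std::unique_ptr<tflite::Interpreter> MakeInterpreter(
        const tflite::FlatBufferModel& model, int32_t num_threads) {
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      // num_threads is applied at construction time, before any delegates run.
      tflite::InterpreterBuilder(model, resolver)(&interpreter, num_threads);
      return interpreter;
    }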