Wire TFLite thread count to the eager context

PiperOrigin-RevId: 226214825
This commit is contained in:
Jared Duke 2018-12-19 12:25:34 -08:00 committed by TensorFlower Gardener
parent 81e98fcb01
commit 2d1a3052e1
8 changed files with 111 additions and 54 deletions

View File

@ -30,6 +30,21 @@ namespace flex {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
// If the TensorFlow Lite thread count is explicitly configured, use it,
// otherwise rely on the default TensorFlow threading behavior.
tensorflow::SessionOptions session_options;
if (context->recommended_num_threads > 0) {
session_options.config.set_intra_op_parallelism_threads(
context->recommended_num_threads);
}
if (!reinterpret_cast<DelegateData*>(delegate->data_)
->Prepare(session_options)
.ok()) {
context->ReportError(context, "Failed to initialize TensorFlow context.");
return kTfLiteError;
}
// Get the nodes in the current execution plan. Interpreter owns this array.
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
@ -118,20 +133,11 @@ AcquireFlexDelegate() {
}
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
std::unique_ptr<flex::DelegateData> delegate_data;
if (!flex::DelegateData::Create(&delegate_data).ok()) {
fprintf(stderr, "Unable to initialize TensorFlow context.\n");
return nullptr;
}
return std::unique_ptr<FlexDelegate>(
new FlexDelegate(std::move(delegate_data)));
return std::unique_ptr<FlexDelegate>(new FlexDelegate());
}
FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
: TfLiteDelegate(TfLiteDelegateCreate()),
delegate_data_(std::move(delegate_data)) {
data_ = delegate_data_.get();
FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
data_ = &delegate_data_;
Prepare = &flex::delegate::Prepare;
CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
flags = kTfLiteDelegateFlagsAllowDynamicTensors;

View File

@ -49,9 +49,9 @@ class FlexDelegate : public TfLiteDelegate {
~FlexDelegate();
private:
explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
FlexDelegate();
std::unique_ptr<flex::DelegateData> delegate_data_;
flex::DelegateData delegate_data_;
};
} // namespace tflite

View File

@ -20,29 +20,32 @@ limitations under the License.
namespace tflite {
namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
DelegateData::DelegateData() {}
DelegateData::~DelegateData() {}
tensorflow::Status DelegateData::Prepare(
const tensorflow::SessionOptions& session_options) {
if (eager_context_) {
return tensorflow::Status();
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
&devices));
session_options, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
data->reset(new DelegateData(new tensorflow::EagerContext(
tensorflow::SessionOptions(),
eager_context_.reset(new tensorflow::EagerContext(
session_options,
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
/*async=*/false, std::move(device_mgr), rendezvous)));
/*async=*/false, std::move(device_mgr), rendezvous));
return tensorflow::Status();
}
DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
: eager_context_(eager_context) {}
DelegateData::~DelegateData() {}
} // namespace flex
} // namespace tflite

View File

@ -15,21 +15,30 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
#define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
#include "tensorflow/lite/delegates/flex/buffer_map.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/lite/delegates/flex/buffer_map.h"
namespace tflite {
namespace flex {
// Data kept by the Flex delegate for the lifetime of an Interpreter.
//
// Note: This class is *not* thread-safe; any dependent delegates should not be
// used concurrently.
class DelegateData {
public:
// Create a new DelegateData, initialized with a newly-created EagerContext.
static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
DelegateData();
~DelegateData();
// Prepare the necessary EagerContext and data for execution.
// This must be called at least once before execution. After preparation
// succeeds, redundant calls will be ignored (even if the session_options
// differ).
tensorflow::Status Prepare(const tensorflow::SessionOptions& session_options);
// The EagerContext that is required for execution of Flex Ops.
// Note: The context is lazily created after the first call to |Prepare()|.
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
@ -38,8 +47,7 @@ class DelegateData {
}
private:
explicit DelegateData(tensorflow::EagerContext* eager_context);
// Will be null until Prepare() is called and completes successfully.
std::unique_ptr<tensorflow::EagerContext> eager_context_;
// TODO(b/112439500): Clean up stale BufferMap instances after adding the
// necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.

View File

@ -24,18 +24,20 @@ namespace flex {
namespace {
TEST(DelegateDataTest, Basic) {
std::unique_ptr<DelegateData> data;
DelegateData data;
// We only check for success because it is hard to make initialization fail.
// It only happens if we manage to not link the CPU device factory into the
// binary.
EXPECT_TRUE(DelegateData::Create(&data).ok());
tensorflow::SessionOptions session_options;
session_options.config.set_intra_op_parallelism_threads(2);
EXPECT_TRUE(data.Prepare(session_options).ok());
TfLiteContext dummy_context1 = {};
TfLiteContext dummy_context2 = {};
EXPECT_NE(data->GetEagerContext(), nullptr);
EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr);
EXPECT_NE(data->GetBufferMap(&dummy_context1),
data->GetBufferMap(&dummy_context2));
EXPECT_NE(data.GetEagerContext(), nullptr);
EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
EXPECT_NE(data.GetBufferMap(&dummy_context1),
data.GetBufferMap(&dummy_context2));
}
} // namespace

View File

@ -252,6 +252,56 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
}
}
TEST_F(DelegateTest, SingleThreaded) {
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
AddTfOp(testing::kUnpack, {0}, {1, 2});
AddTfOp(testing::kUnpack, {3}, {4, 5});
AddTfOp(testing::kAdd, {1, 4}, {6});
AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfOp(testing::kMul, {6, 7}, {8});
// Explicitly disable multi-threading before installing the delegate.
interpreter_->SetNumThreads(1);
ConfigureDelegate();
SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
SetShape(3, {2, 2, 1});
SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
// Invocation should behave as expected.
ASSERT_TRUE(Invoke());
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
TEST_F(DelegateTest, MultiThreaded) {
AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
AddTfOp(testing::kUnpack, {0}, {1, 2});
AddTfOp(testing::kUnpack, {3}, {4, 5});
AddTfOp(testing::kAdd, {1, 4}, {6});
AddTfOp(testing::kAdd, {2, 5}, {7});
AddTfOp(testing::kMul, {6, 7}, {8});
// Explicitly enable multi-threading before installing the delegate.
interpreter_->SetNumThreads(4);
ConfigureDelegate();
SetShape(0, {2, 2, 1});
SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
SetShape(3, {2, 2, 1});
SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
// Invocation should behave as expected.
ASSERT_TRUE(Invoke());
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
ASSERT_EQ(GetType(8), kTfLiteFloat32);
}
} // namespace
} // namespace flex
} // namespace tflite

View File

@ -39,20 +39,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
class KernelTest : public testing::FlexModelTest {
public:
KernelTest() {
CHECK(DelegateData::Create(&delegate_data_).ok());
CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
interpreter_.reset(new Interpreter(&error_reporter_));
}
~KernelTest() override {
// The data needs to be released before the interpreter because the
// interpreter references the data.
delegate_data_.reset();
interpreter_.reset();
}
template <typename T>
void ConfigureDelegate(T prepare_function) {
delegate_.data_ = delegate_data_.get();
delegate_.data_ = &delegate_data_;
delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
delegate_.FreeBufferHandle = nullptr;
delegate_.Prepare = prepare_function;
@ -71,7 +64,7 @@ class KernelTest : public testing::FlexModelTest {
}
private:
std::unique_ptr<DelegateData> delegate_data_;
DelegateData delegate_data_;
TfLiteDelegate delegate_;
};

View File

@ -316,18 +316,13 @@ void BenchmarkTfLiteModel::Init() {
tflite::ops::builtin::BuiltinOpResolver resolver;
#endif
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
const int32_t num_threads = params_.Get<int32_t>("num_threads");
tflite::InterpreterBuilder(*model, resolver)(&interpreter, num_threads);
if (!interpreter) {
TFLITE_LOG(FATAL) << "Failed to construct interpreter";
}
profiling_listener_.SetInterpreter(interpreter.get());
const int32_t num_threads = params_.Get<int32_t>("num_threads");
if (num_threads != -1) {
interpreter->SetNumThreads(num_threads);
}
bool use_nnapi = params_.Get<bool>("use_nnapi");
interpreter->UseNNAPI(use_nnapi);