diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc
index ca7314fbaee..dcf5b795d82 100644
--- a/tensorflow/lite/delegates/flex/delegate.cc
+++ b/tensorflow/lite/delegates/flex/delegate.cc
@@ -30,6 +30,21 @@ namespace flex {
 namespace delegate {
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
+  // If the TensorFlow Lite thread count is explicitly configured, use it,
+  // otherwise rely on the default TensorFlow threading behavior.
+  tensorflow::SessionOptions session_options;
+  if (context->recommended_num_threads > 0) {
+    session_options.config.set_intra_op_parallelism_threads(
+        context->recommended_num_threads);
+  }
+
+  if (!reinterpret_cast<DelegateData*>(delegate->data_)
+           ->Prepare(session_options)
+           .ok()) {
+    context->ReportError(context, "Failed to initialize TensorFlow context.");
+    return kTfLiteError;
+  }
+
   // Get the nodes in the current execution plan. Interpreter owns this array.
   TfLiteIntArray* plan;
   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
@@ -118,20 +133,11 @@ AcquireFlexDelegate() {
 }
 
 std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
-  std::unique_ptr<flex::DelegateData> delegate_data;
-  if (!flex::DelegateData::Create(&delegate_data).ok()) {
-    fprintf(stderr, "Unable to initialize TensorFlow context.\n");
-    return nullptr;
-  }
-
-  return std::unique_ptr<FlexDelegate>(
-      new FlexDelegate(std::move(delegate_data)));
+  return std::unique_ptr<FlexDelegate>(new FlexDelegate());
 }
 
-FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
-    : TfLiteDelegate(TfLiteDelegateCreate()),
-      delegate_data_(std::move(delegate_data)) {
-  data_ = delegate_data_.get();
+FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) {
+  data_ = &delegate_data_;
   Prepare = &flex::delegate::Prepare;
   CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle;
   flags = kTfLiteDelegateFlagsAllowDynamicTensors;
diff --git a/tensorflow/lite/delegates/flex/delegate.h b/tensorflow/lite/delegates/flex/delegate.h
index 018ff3e0b0e..767cbe13c4e 100644
--- a/tensorflow/lite/delegates/flex/delegate.h
+++ b/tensorflow/lite/delegates/flex/delegate.h
@@ -49,9 +49,9 @@ class FlexDelegate : public TfLiteDelegate {
   ~FlexDelegate();
 
  private:
-  explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
+  FlexDelegate();
 
-  std::unique_ptr<flex::DelegateData> delegate_data_;
+  flex::DelegateData delegate_data_;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc
index 1483a530388..87f37697468 100644
--- a/tensorflow/lite/delegates/flex/delegate_data.cc
+++ b/tensorflow/lite/delegates/flex/delegate_data.cc
@@ -20,29 +20,32 @@ limitations under the License.
 namespace tflite {
 namespace flex {
 
-tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
+DelegateData::DelegateData() {}
+
+DelegateData::~DelegateData() {}
+
+tensorflow::Status DelegateData::Prepare(
+    const tensorflow::SessionOptions& session_options) {
+  if (eager_context_) {
+    return tensorflow::Status();
+  }
+
   std::vector<std::unique_ptr<tensorflow::Device>> devices;
 
   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
-      tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
-      &devices));
+      session_options, "/job:localhost/replica:0/task:0", &devices));
 
   std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
       absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
   // Note that Rendezvous is ref-counted so it will be automatically deleted.
   tensorflow::Rendezvous* rendezvous =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
-  data->reset(new DelegateData(new tensorflow::EagerContext(
-      tensorflow::SessionOptions(),
+  eager_context_.reset(new tensorflow::EagerContext(
+      session_options,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      /*async=*/false, std::move(device_mgr), rendezvous)));
+      /*async=*/false, std::move(device_mgr), rendezvous));
   return tensorflow::Status();
 }
 
-DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
-    : eager_context_(eager_context) {}
-
-DelegateData::~DelegateData() {}
-
 }  // namespace flex
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/delegate_data.h b/tensorflow/lite/delegates/flex/delegate_data.h
index a88cc98d03c..20d6b40a5d2 100644
--- a/tensorflow/lite/delegates/flex/delegate_data.h
+++ b/tensorflow/lite/delegates/flex/delegate_data.h
@@ -15,21 +15,30 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
 #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
 
-#include "tensorflow/lite/delegates/flex/buffer_map.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/lite/delegates/flex/buffer_map.h"
 
 namespace tflite {
 namespace flex {
 
 // Data kept by the Flex delegate for the lifetime of an Interpreter.
+//
+// Note: This class is *not* thread-safe; any dependent delegates should not be
+// used concurrently.
 class DelegateData {
  public:
-  // Create a new DelegateData, initialized with a newly-created EagerContext.
-  static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
-
+  DelegateData();
   ~DelegateData();
 
+  // Prepare the necessary EagerContext and data for execution.
+  // This must be called at least once before execution. After preparation
+  // succeeds, redundant calls will be ignored (even if the session_options
+  // differ).
+  tensorflow::Status Prepare(const tensorflow::SessionOptions& session_options);
+
   // The EagerContext that is required for execution of Flex Ops.
+  // Note: The context is lazily created after the first call to |Prepare()|.
   tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
 
   // Map from TF Lite tensor index to TensorFlow tensor for a given context.
@@ -38,8 +47,7 @@ class DelegateData {
   }
 
  private:
-  explicit DelegateData(tensorflow::EagerContext* eager_context);
-
+  // Will be null until Prepare() is called and completes successfully.
   std::unique_ptr<tensorflow::EagerContext> eager_context_;
   // TODO(b/112439500): Clean up stale BufferMap instances after adding the
   // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
diff --git a/tensorflow/lite/delegates/flex/delegate_data_test.cc b/tensorflow/lite/delegates/flex/delegate_data_test.cc
index cd274e7cb1c..22b8e436fb5 100644
--- a/tensorflow/lite/delegates/flex/delegate_data_test.cc
+++ b/tensorflow/lite/delegates/flex/delegate_data_test.cc
@@ -24,18 +24,20 @@ namespace flex {
 namespace {
 
 TEST(DelegateDataTest, Basic) {
-  std::unique_ptr<DelegateData> data;
+  DelegateData data;
   // We only check for success because it is hard to make initialization fail.
   // It only happens if we manage to not link the CPU device factory into the
   // binary.
-  EXPECT_TRUE(DelegateData::Create(&data).ok());
+  tensorflow::SessionOptions session_options;
+  session_options.config.set_intra_op_parallelism_threads(2);
+  EXPECT_TRUE(data.Prepare(session_options).ok());
 
   TfLiteContext dummy_context1 = {};
   TfLiteContext dummy_context2 = {};
-  EXPECT_NE(data->GetEagerContext(), nullptr);
-  EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr);
-  EXPECT_NE(data->GetBufferMap(&dummy_context1),
-            data->GetBufferMap(&dummy_context2));
+  EXPECT_NE(data.GetEagerContext(), nullptr);
+  EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr);
+  EXPECT_NE(data.GetBufferMap(&dummy_context1),
+            data.GetBufferMap(&dummy_context2));
 }
 
 }  // namespace
diff --git a/tensorflow/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc
index ee37090d94e..b48fe181e1f 100644
--- a/tensorflow/lite/delegates/flex/delegate_test.cc
+++ b/tensorflow/lite/delegates/flex/delegate_test.cc
@@ -252,6 +252,56 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
   }
 }
 
+TEST_F(DelegateTest, SingleThreaded) {
+  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+  AddTfOp(testing::kUnpack, {0}, {1, 2});
+  AddTfOp(testing::kUnpack, {3}, {4, 5});
+  AddTfOp(testing::kAdd, {1, 4}, {6});
+  AddTfOp(testing::kAdd, {2, 5}, {7});
+  AddTfOp(testing::kMul, {6, 7}, {8});
+
+  // Explicitly disable multi-threading before installing the delegate.
+  interpreter_->SetNumThreads(1);
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2, 1});
+  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+  SetShape(3, {2, 2, 1});
+  SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+  // Invocation should behave as expected.
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, MultiThreaded) {
+  AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+  AddTfOp(testing::kUnpack, {0}, {1, 2});
+  AddTfOp(testing::kUnpack, {3}, {4, 5});
+  AddTfOp(testing::kAdd, {1, 4}, {6});
+  AddTfOp(testing::kAdd, {2, 5}, {7});
+  AddTfOp(testing::kMul, {6, 7}, {8});
+
+  // Explicitly enable multi-threading before installing the delegate.
+  interpreter_->SetNumThreads(4);
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2, 1});
+  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+  SetShape(3, {2, 2, 1});
+  SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+  // Invocation should behave as expected.
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+  ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
 }  // namespace
 }  // namespace flex
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/kernel_test.cc b/tensorflow/lite/delegates/flex/kernel_test.cc
index efb7300b0bd..cc5c8b32a01 100644
--- a/tensorflow/lite/delegates/flex/kernel_test.cc
+++ b/tensorflow/lite/delegates/flex/kernel_test.cc
@@ -39,20 +39,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
 class KernelTest : public testing::FlexModelTest {
  public:
   KernelTest() {
-    CHECK(DelegateData::Create(&delegate_data_).ok());
+    CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok());
     interpreter_.reset(new Interpreter(&error_reporter_));
   }
 
-  ~KernelTest() override {
-    // The data needs to be released before the interpreter because the
-    // interpreter references the data.
-    delegate_data_.reset();
-    interpreter_.reset();
-  }
-
   template <typename T>
   void ConfigureDelegate(T prepare_function) {
-    delegate_.data_ = delegate_data_.get();
+    delegate_.data_ = &delegate_data_;
     delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors;
     delegate_.FreeBufferHandle = nullptr;
     delegate_.Prepare = prepare_function;
@@ -71,7 +64,7 @@ class KernelTest : public testing::FlexModelTest {
   }
 
  private:
-  std::unique_ptr<DelegateData> delegate_data_;
+  DelegateData delegate_data_;
   TfLiteDelegate delegate_;
 };
 
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 32cf4e4292a..0bc7565e82c 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -316,18 +316,13 @@ void BenchmarkTfLiteModel::Init() {
   tflite::ops::builtin::BuiltinOpResolver resolver;
 #endif
 
-  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+  const int32_t num_threads = params_.Get<int32_t>("num_threads");
+  tflite::InterpreterBuilder(*model, resolver)(&interpreter, num_threads);
   if (!interpreter) {
     TFLITE_LOG(FATAL) << "Failed to construct interpreter";
   }
   profiling_listener_.SetInterpreter(interpreter.get());
 
-  const int32_t num_threads = params_.Get<int32_t>("num_threads");
-
-  if (num_threads != -1) {
-    interpreter->SetNumThreads(num_threads);
-  }
-
   bool use_nnapi = params_.Get<bool>("use_nnapi");
 
   interpreter->UseNNAPI(use_nnapi);
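
For anyone who wants to exercise the new threading path outside the included tests, the end-to-end flow looks roughly like the sketch below. This is illustrative only and not part of the patch: the model path, the thread count of 4, and the helper name RunFlexModelWithFourThreads are placeholder assumptions. The one real constraint it demonstrates is ordering: SetNumThreads() has to run before the delegate is applied, because flex::delegate::Prepare() reads context->recommended_num_threads at installation time and forwards it into the TensorFlow SessionOptions as intra_op_parallelism_threads. (When the Flex library is linked in, AcquireFlexDelegate() normally installs the delegate automatically; the explicit ModifyGraphWithDelegate() call here just makes the ordering visible.)

// Usage sketch only -- not part of this change. Model path, thread count, and
// function name are placeholders; error handling is reduced to early returns.
#include <memory>

#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int RunFlexModelWithFourThreads() {
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "/tmp/model_with_flex_ops.tflite");  // Placeholder path.
  if (!model) return 1;

  // Created before the interpreter so it is destroyed after it
  // (locals are destroyed in reverse declaration order).
  std::unique_ptr<tflite::FlexDelegate> delegate = tflite::FlexDelegate::Create();
  if (!delegate) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return 1;
  }

  // Must precede delegate installation: flex::delegate::Prepare() copies
  // context->recommended_num_threads into the SessionOptions as
  // intra_op_parallelism_threads.
  interpreter->SetNumThreads(4);

  if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
    return 1;
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) return 1;
  return interpreter->Invoke() == kTfLiteOk ? 0 : 1;
}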