Wire TFLite thread count to the eager context
PiperOrigin-RevId: 226214825
This commit is contained in:
		
							parent
							
								
									81e98fcb01
								
							
						
					
					
						commit
						2d1a3052e1
					
				| @ -30,6 +30,21 @@ namespace flex { | |||||||
| namespace delegate { | namespace delegate { | ||||||
| 
 | 
 | ||||||
| TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { | TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { | ||||||
|  |   // If the TensorFlow Lite thread count is explicitly configured, use it,
 | ||||||
|  |   // otherwise rely on the default TensorFlow threading behavior.
 | ||||||
|  |   tensorflow::SessionOptions session_options; | ||||||
|  |   if (context->recommended_num_threads > 0) { | ||||||
|  |     session_options.config.set_intra_op_parallelism_threads( | ||||||
|  |         context->recommended_num_threads); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   if (!reinterpret_cast<DelegateData*>(delegate->data_) | ||||||
|  |            ->Prepare(session_options) | ||||||
|  |            .ok()) { | ||||||
|  |     context->ReportError(context, "Failed to initialize TensorFlow context."); | ||||||
|  |     return kTfLiteError; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   // Get the nodes in the current execution plan. Interpreter owns this array.
 |   // Get the nodes in the current execution plan. Interpreter owns this array.
 | ||||||
|   TfLiteIntArray* plan; |   TfLiteIntArray* plan; | ||||||
|   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); |   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); | ||||||
| @ -118,20 +133,11 @@ AcquireFlexDelegate() { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::unique_ptr<FlexDelegate> FlexDelegate::Create() { | std::unique_ptr<FlexDelegate> FlexDelegate::Create() { | ||||||
|   std::unique_ptr<flex::DelegateData> delegate_data; |   return std::unique_ptr<FlexDelegate>(new FlexDelegate()); | ||||||
|   if (!flex::DelegateData::Create(&delegate_data).ok()) { |  | ||||||
|     fprintf(stderr, "Unable to initialize TensorFlow context.\n"); |  | ||||||
|     return nullptr; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|   return std::unique_ptr<FlexDelegate>( | FlexDelegate::FlexDelegate() : TfLiteDelegate(TfLiteDelegateCreate()) { | ||||||
|       new FlexDelegate(std::move(delegate_data))); |   data_ = &delegate_data_; | ||||||
| } |  | ||||||
| 
 |  | ||||||
| FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data) |  | ||||||
|     : TfLiteDelegate(TfLiteDelegateCreate()), |  | ||||||
|       delegate_data_(std::move(delegate_data)) { |  | ||||||
|   data_ = delegate_data_.get(); |  | ||||||
|   Prepare = &flex::delegate::Prepare; |   Prepare = &flex::delegate::Prepare; | ||||||
|   CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle; |   CopyFromBufferHandle = &flex::delegate::CopyFromBufferHandle; | ||||||
|   flags = kTfLiteDelegateFlagsAllowDynamicTensors; |   flags = kTfLiteDelegateFlagsAllowDynamicTensors; | ||||||
|  | |||||||
| @ -49,9 +49,9 @@ class FlexDelegate : public TfLiteDelegate { | |||||||
|   ~FlexDelegate(); |   ~FlexDelegate(); | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data); |   FlexDelegate(); | ||||||
| 
 | 
 | ||||||
|   std::unique_ptr<flex::DelegateData> delegate_data_; |   flex::DelegateData delegate_data_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace tflite
 | }  // namespace tflite
 | ||||||
|  | |||||||
| @ -20,29 +20,32 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| namespace tflite { | namespace tflite { | ||||||
| namespace flex { | namespace flex { | ||||||
| tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) { | DelegateData::DelegateData() {} | ||||||
|  | 
 | ||||||
|  | DelegateData::~DelegateData() {} | ||||||
|  | 
 | ||||||
|  | tensorflow::Status DelegateData::Prepare( | ||||||
|  |     const tensorflow::SessionOptions& session_options) { | ||||||
|  |   if (eager_context_) { | ||||||
|  |     return tensorflow::Status(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   std::vector<std::unique_ptr<tensorflow::Device>> devices; |   std::vector<std::unique_ptr<tensorflow::Device>> devices; | ||||||
| 
 | 
 | ||||||
|   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( |   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( | ||||||
|       tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", |       session_options, "/job:localhost/replica:0/task:0", &devices)); | ||||||
|       &devices)); |  | ||||||
| 
 | 
 | ||||||
|   std::unique_ptr<tensorflow::DeviceMgr> device_mgr = |   std::unique_ptr<tensorflow::DeviceMgr> device_mgr = | ||||||
|       absl::make_unique<tensorflow::DeviceMgr>(std::move(devices)); |       absl::make_unique<tensorflow::DeviceMgr>(std::move(devices)); | ||||||
|   // Note that Rendezvous is ref-counted so it will be automatically deleted.
 |   // Note that Rendezvous is ref-counted so it will be automatically deleted.
 | ||||||
|   tensorflow::Rendezvous* rendezvous = |   tensorflow::Rendezvous* rendezvous = | ||||||
|       new tensorflow::IntraProcessRendezvous(device_mgr.get()); |       new tensorflow::IntraProcessRendezvous(device_mgr.get()); | ||||||
|   data->reset(new DelegateData(new tensorflow::EagerContext( |   eager_context_.reset(new tensorflow::EagerContext( | ||||||
|       tensorflow::SessionOptions(), |       session_options, | ||||||
|       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, |       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, | ||||||
|       /*async=*/false, std::move(device_mgr), rendezvous))); |       /*async=*/false, std::move(device_mgr), rendezvous)); | ||||||
|   return tensorflow::Status(); |   return tensorflow::Status(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| DelegateData::DelegateData(tensorflow::EagerContext* eager_context) |  | ||||||
|     : eager_context_(eager_context) {} |  | ||||||
| 
 |  | ||||||
| DelegateData::~DelegateData() {} |  | ||||||
| 
 |  | ||||||
| }  // namespace flex
 | }  // namespace flex
 | ||||||
| }  // namespace tflite
 | }  // namespace tflite
 | ||||||
|  | |||||||
| @ -15,21 +15,30 @@ limitations under the License. | |||||||
| #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ | #ifndef TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ | ||||||
| #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ | #define TENSORFLOW_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/lite/delegates/flex/buffer_map.h" |  | ||||||
| #include "tensorflow/core/common_runtime/eager/context.h" | #include "tensorflow/core/common_runtime/eager/context.h" | ||||||
|  | #include "tensorflow/core/public/session_options.h" | ||||||
|  | #include "tensorflow/lite/delegates/flex/buffer_map.h" | ||||||
| 
 | 
 | ||||||
| namespace tflite { | namespace tflite { | ||||||
| namespace flex { | namespace flex { | ||||||
| 
 | 
 | ||||||
| // Data kept by the Flex delegate for the lifetime of an Interpreter.
 | // Data kept by the Flex delegate for the lifetime of an Interpreter.
 | ||||||
|  | //
 | ||||||
|  | // Note: This class is *not* thread-safe; any dependent delegates should not be
 | ||||||
|  | // used concurrently.
 | ||||||
| class DelegateData { | class DelegateData { | ||||||
|  public: |  public: | ||||||
|   // Create a new DelegateData, initialized with a newly-created EagerContext.
 |   DelegateData(); | ||||||
|   static tensorflow::Status Create(std::unique_ptr<DelegateData>* data); |  | ||||||
| 
 |  | ||||||
|   ~DelegateData(); |   ~DelegateData(); | ||||||
| 
 | 
 | ||||||
|  |   // Prepare the necessary EagerContext and data for execution.
 | ||||||
|  |   // This must be called at least once before execution. After preparation
 | ||||||
|  |   // succeeds, redundant calls will be ignored (even if the session_options
 | ||||||
|  |   // differ).
 | ||||||
|  |   tensorflow::Status Prepare(const tensorflow::SessionOptions& session_options); | ||||||
|  | 
 | ||||||
|   // The EagerContext that is required for execution of Flex Ops.
 |   // The EagerContext that is required for execution of Flex Ops.
 | ||||||
|  |   // Note: The context is lazily created after the first call to |Prepare()|.
 | ||||||
|   tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } |   tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } | ||||||
| 
 | 
 | ||||||
|   // Map from TF Lite tensor index to TensorFlow tensor for a given context.
 |   // Map from TF Lite tensor index to TensorFlow tensor for a given context.
 | ||||||
| @ -38,8 +47,7 @@ class DelegateData { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   explicit DelegateData(tensorflow::EagerContext* eager_context); |   // Will be null until Prepare() is called and completes successfully.
 | ||||||
| 
 |  | ||||||
|   std::unique_ptr<tensorflow::EagerContext> eager_context_; |   std::unique_ptr<tensorflow::EagerContext> eager_context_; | ||||||
|   // TODO(b/112439500): Clean up stale BufferMap instances after adding the
 |   // TODO(b/112439500): Clean up stale BufferMap instances after adding the
 | ||||||
|   // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
 |   // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate.
 | ||||||
|  | |||||||
| @ -24,18 +24,20 @@ namespace flex { | |||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| TEST(DelegateDataTest, Basic) { | TEST(DelegateDataTest, Basic) { | ||||||
|   std::unique_ptr<DelegateData> data; |   DelegateData data; | ||||||
|   // We only check for success because it is hard to make initialization fail.
 |   // We only check for success because it is hard to make initialization fail.
 | ||||||
|   // It only happens if we manage to not link the CPU device factory into the
 |   // It only happens if we manage to not link the CPU device factory into the
 | ||||||
|   // binary.
 |   // binary.
 | ||||||
|   EXPECT_TRUE(DelegateData::Create(&data).ok()); |   tensorflow::SessionOptions session_options; | ||||||
|  |   session_options.config.set_intra_op_parallelism_threads(2); | ||||||
|  |   EXPECT_TRUE(data.Prepare(session_options).ok()); | ||||||
| 
 | 
 | ||||||
|   TfLiteContext dummy_context1 = {}; |   TfLiteContext dummy_context1 = {}; | ||||||
|   TfLiteContext dummy_context2 = {}; |   TfLiteContext dummy_context2 = {}; | ||||||
|   EXPECT_NE(data->GetEagerContext(), nullptr); |   EXPECT_NE(data.GetEagerContext(), nullptr); | ||||||
|   EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr); |   EXPECT_NE(data.GetBufferMap(&dummy_context1), nullptr); | ||||||
|   EXPECT_NE(data->GetBufferMap(&dummy_context1), |   EXPECT_NE(data.GetBufferMap(&dummy_context1), | ||||||
|             data->GetBufferMap(&dummy_context2)); |             data.GetBufferMap(&dummy_context2)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
|  | |||||||
| @ -252,6 +252,56 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(DelegateTest, SingleThreaded) { | ||||||
|  |   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); | ||||||
|  |   AddTfOp(testing::kUnpack, {0}, {1, 2}); | ||||||
|  |   AddTfOp(testing::kUnpack, {3}, {4, 5}); | ||||||
|  |   AddTfOp(testing::kAdd, {1, 4}, {6}); | ||||||
|  |   AddTfOp(testing::kAdd, {2, 5}, {7}); | ||||||
|  |   AddTfOp(testing::kMul, {6, 7}, {8}); | ||||||
|  | 
 | ||||||
|  |   // Explicitly disable multi-threading before installing the delegate.
 | ||||||
|  |   interpreter_->SetNumThreads(1); | ||||||
|  |   ConfigureDelegate(); | ||||||
|  | 
 | ||||||
|  |   SetShape(0, {2, 2, 1}); | ||||||
|  |   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); | ||||||
|  |   SetShape(3, {2, 2, 1}); | ||||||
|  |   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); | ||||||
|  | 
 | ||||||
|  |   // Invocation should behave as expected.
 | ||||||
|  |   ASSERT_TRUE(Invoke()); | ||||||
|  | 
 | ||||||
|  |   ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); | ||||||
|  |   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); | ||||||
|  |   ASSERT_EQ(GetType(8), kTfLiteFloat32); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(DelegateTest, MultiThreaded) { | ||||||
|  |   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); | ||||||
|  |   AddTfOp(testing::kUnpack, {0}, {1, 2}); | ||||||
|  |   AddTfOp(testing::kUnpack, {3}, {4, 5}); | ||||||
|  |   AddTfOp(testing::kAdd, {1, 4}, {6}); | ||||||
|  |   AddTfOp(testing::kAdd, {2, 5}, {7}); | ||||||
|  |   AddTfOp(testing::kMul, {6, 7}, {8}); | ||||||
|  | 
 | ||||||
|  |   // Explicitly enable multi-threading before installing the delegate.
 | ||||||
|  |   interpreter_->SetNumThreads(4); | ||||||
|  |   ConfigureDelegate(); | ||||||
|  | 
 | ||||||
|  |   SetShape(0, {2, 2, 1}); | ||||||
|  |   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); | ||||||
|  |   SetShape(3, {2, 2, 1}); | ||||||
|  |   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); | ||||||
|  | 
 | ||||||
|  |   // Invocation should behave as expected.
 | ||||||
|  |   ASSERT_TRUE(Invoke()); | ||||||
|  | 
 | ||||||
|  |   ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); | ||||||
|  |   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); | ||||||
|  |   ASSERT_EQ(GetType(8), kTfLiteFloat32); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace flex
 | }  // namespace flex
 | ||||||
| }  // namespace tflite
 | }  // namespace tflite
 | ||||||
|  | |||||||
| @ -39,20 +39,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, | |||||||
| class KernelTest : public testing::FlexModelTest { | class KernelTest : public testing::FlexModelTest { | ||||||
|  public: |  public: | ||||||
|   KernelTest() { |   KernelTest() { | ||||||
|     CHECK(DelegateData::Create(&delegate_data_).ok()); |     CHECK(delegate_data_.Prepare(tensorflow::SessionOptions{}).ok()); | ||||||
|     interpreter_.reset(new Interpreter(&error_reporter_)); |     interpreter_.reset(new Interpreter(&error_reporter_)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   ~KernelTest() override { |  | ||||||
|     // The data needs to be released before the interpreter because the
 |  | ||||||
|     // interpreter references the data.
 |  | ||||||
|     delegate_data_.reset(); |  | ||||||
|     interpreter_.reset(); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   void ConfigureDelegate(T prepare_function) { |   void ConfigureDelegate(T prepare_function) { | ||||||
|     delegate_.data_ = delegate_data_.get(); |     delegate_.data_ = &delegate_data_; | ||||||
|     delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; |     delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; | ||||||
|     delegate_.FreeBufferHandle = nullptr; |     delegate_.FreeBufferHandle = nullptr; | ||||||
|     delegate_.Prepare = prepare_function; |     delegate_.Prepare = prepare_function; | ||||||
| @ -71,7 +64,7 @@ class KernelTest : public testing::FlexModelTest { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   std::unique_ptr<DelegateData> delegate_data_; |   DelegateData delegate_data_; | ||||||
|   TfLiteDelegate delegate_; |   TfLiteDelegate delegate_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -316,18 +316,13 @@ void BenchmarkTfLiteModel::Init() { | |||||||
|   tflite::ops::builtin::BuiltinOpResolver resolver; |   tflite::ops::builtin::BuiltinOpResolver resolver; | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
|   tflite::InterpreterBuilder(*model, resolver)(&interpreter); |   const int32_t num_threads = params_.Get<int32_t>("num_threads"); | ||||||
|  |   tflite::InterpreterBuilder(*model, resolver)(&interpreter, num_threads); | ||||||
|   if (!interpreter) { |   if (!interpreter) { | ||||||
|     TFLITE_LOG(FATAL) << "Failed to construct interpreter"; |     TFLITE_LOG(FATAL) << "Failed to construct interpreter"; | ||||||
|   } |   } | ||||||
|   profiling_listener_.SetInterpreter(interpreter.get()); |   profiling_listener_.SetInterpreter(interpreter.get()); | ||||||
| 
 | 
 | ||||||
|   const int32_t num_threads = params_.Get<int32_t>("num_threads"); |  | ||||||
| 
 |  | ||||||
|   if (num_threads != -1) { |  | ||||||
|     interpreter->SetNumThreads(num_threads); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   bool use_nnapi = params_.Get<bool>("use_nnapi"); |   bool use_nnapi = params_.Get<bool>("use_nnapi"); | ||||||
| 
 | 
 | ||||||
|   interpreter->UseNNAPI(use_nnapi); |   interpreter->UseNNAPI(use_nnapi); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user