Usually, the flex delegate is enabled by a symbol override of the AcquireFlexDelegate() function, but this approach doesn't work well with shared libraries. Since pywrap_tensorflow_internal.so is available in the TensorFlow PIP package, I've made the following changes to enable the flex delegate. - Included the flex delegate module in pywrap_tensorflow_internal.so. This file already contains most of the TF internal logic, and adding the TFLite flex delegate increases the output size by about 72K. - Added a new function, TF_AcquireFlexDelegate(), to the delegate module. - Updated the logic in AcquireFlexDelegate() in interpreter_builder.cc to check for the availability of pywrap_tensorflow_internal.so and look up the TF_AcquireFlexDelegate() symbol to enable the flex delegate. Also updated python/lite_flex_test.py, since the flex delegate is now supported by the Python API. PiperOrigin-RevId: 317044994 Change-Id: Ic5e953f4a675b3f5360a4c7d607568193103711a
146 lines
5.1 KiB
C++
146 lines
5.1 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
#include "tensorflow/lite/delegates/flex/delegate.h"

#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/flex/buffer_map.h"
#include "tensorflow/lite/delegates/flex/kernel.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"
|
|
|
|
namespace tflite {
|
|
|
|
// Corresponding weak declaration found in lite/interpreter_builder.cc.
|
|
TfLiteDelegateUniquePtr AcquireFlexDelegate() {
|
|
return tflite::FlexDelegate::Create();
|
|
}
|
|
|
|
// Wraps a FlexDelegate in a TfLiteDelegate handle.
//
// `base_delegate` may be null, in which case a default-constructed
// FlexDelegate is used instead. The returned delegate has its
// CopyFromBufferHandle callback routed to
// FlexDelegate::CopyFromBufferHandle() and carries the
// kTfLiteDelegateFlagsAllowDynamicTensors flag.
TfLiteDelegateUniquePtr FlexDelegate::Create(
    std::unique_ptr<FlexDelegate> base_delegate) {
  TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO,
                       "Created TensorFlow Lite delegate for select TF ops.");

  if (!base_delegate) {
    base_delegate = std::unique_ptr<FlexDelegate>(new FlexDelegate());
  }

  auto wrapped = TfLiteDelegateFactory::Create(std::move(base_delegate));

  // C-style callback: recover the FlexDelegate from the opaque `data_` field
  // and forward to the member function of the same name.
  auto copy_from_buffer_handle =
      [](TfLiteContext* context, TfLiteDelegate* delegate,
         TfLiteBufferHandle buffer_handle,
         TfLiteTensor* tensor) -> TfLiteStatus {
    auto* self = reinterpret_cast<FlexDelegate*>(delegate->data_);
    return self->CopyFromBufferHandle(context, buffer_handle, tensor);
  };
  wrapped->CopyFromBufferHandle = copy_from_buffer_handle;

  wrapped->flags |= kTfLiteDelegateFlagsAllowDynamicTensors;
  return wrapped;
}
|
|
|
|
// Prepares the delegate's TensorFlow-side data for use with `context`.
//
// Returns kTfLiteOk when preparation succeeds. Otherwise the failure message
// is reported through `context` and kTfLiteError is returned.
TfLiteStatus FlexDelegate::Initialize(TfLiteContext* context) {
  tensorflow::SessionOptions session_options;

  // Forward the thread count only when TensorFlow Lite has one explicitly
  // configured; otherwise TensorFlow's default threading behavior applies.
  const auto num_threads = context->recommended_num_threads;
  if (num_threads > 0) {
    session_options.config.set_intra_op_parallelism_threads(num_threads);
  }

  const auto status = delegate_data_.Prepare(session_options);
  if (status.ok()) {
    return kTfLiteOk;
  }

  context->ReportError(context, "Failed to initialize TensorFlow context: %s",
                       status.error_message().c_str());
  return kTfLiteError;
}
|
|
|
|
// Returns the identifier string of this delegate. The returned pointer refers
// to a string literal, so it stays valid for the lifetime of the program.
const char* FlexDelegate::Name() const { return "TfLiteFlexDelegate"; }
|
|
|
|
bool FlexDelegate::IsNodeSupportedByDelegate(
|
|
const TfLiteRegistration* registration, const TfLiteNode* node,
|
|
TfLiteContext* context) const {
|
|
return IsFlexOp(registration->custom_name);
|
|
}
|
|
|
|
std::unique_ptr<SimpleDelegateKernelInterface>
|
|
FlexDelegate::CreateDelegateKernelInterface() {
|
|
return std::unique_ptr<SimpleDelegateKernelInterface>(
|
|
new tflite::flex::DelegateKernel());
|
|
}
|
|
|
|
// Copies the TensorFlow tensor registered under `buffer_handle` into the
// TensorFlow Lite tensor `output`.
//
// When `output` is a string tensor, the TensorFlow strings are re-serialized
// element by element into `output`. For every other type the data is copied
// as raw bytes, which requires `output->bytes` to be exactly equal to the size
// of the TensorFlow tensor's buffer.
//
// Returns kTfLiteOk on success. Returns kTfLiteError, after reporting through
// `context`, when the handle is unknown to the buffer map, when a string
// output is backed by a non-string TensorFlow tensor, or when the byte sizes
// differ.
TfLiteStatus FlexDelegate::CopyFromBufferHandle(
    TfLiteContext* context, TfLiteBufferHandle buffer_handle,
    TfLiteTensor* output) {
  flex::BufferMap* buffer_map = delegate_data_.GetBufferMap(context);

  if (!buffer_map->HasTensor(buffer_handle)) {
    context->ReportError(context, "Invalid tensor index %d.", buffer_handle);
    return kTfLiteError;
  }

  tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle);

  if (output->type == kTfLiteString) {
    if (t.dtype() != tensorflow::DT_STRING) {
      context->ReportError(context,
                           "Inconsistent type for TF string tensor index %d.",
                           buffer_handle);
      return kTfLiteError;
    }
    DynamicBuffer dynamic_buffer;

    auto tf_data = t.flat<tensorflow::tstring>();
    // The element count is a 64-bit value; index with a matching type so the
    // counter cannot overflow, and query the count only once.
    const int64_t num_strings = t.NumElements();
    for (int64_t i = 0; i < num_strings; ++i) {
      const tensorflow::tstring& element = tf_data(i);
      dynamic_buffer.AddString(element.data(), element.size());
    }

    dynamic_buffer.WriteToTensor(output, /*new_shape=*/nullptr);
    return kTfLiteOk;
  }

  tensorflow::StringPiece t_data = t.tensor_data();

  // The sizes must match exactly; a larger output buffer is rejected as well.
  if (output->bytes != t_data.size()) {
    // ReportError takes a printf-style format. The message is built at
    // runtime, so pass it as an argument to "%s" rather than as the format
    // string itself.
    context->ReportError(context, "%s",
                         absl::StrCat("The given ", output->bytes,
                                      " bytes are not enough to store "
                                      "TensorFlow's aligned buffer of size ",
                                      t_data.size(), " bytes.")
                             .c_str());
    return kTfLiteError;
  }

  // memcpy requires valid pointers even for a zero-length copy, and empty
  // tensors may not have backing storage, so skip the call when there is
  // nothing to copy.
  if (t_data.size() > 0) {
    memcpy(output->data.raw, t_data.data(), t_data.size());
  }
  return kTfLiteOk;
}
|
|
|
|
} // namespace tflite
|
|
|
|
// Exported C interface function which is used by AcquireFlexDelegate() at
|
|
// interpreter_build.cc. To export the function name globally, the function name
|
|
// must be matched with patterns in tf_version_script.lds
|
|
extern "C" tflite::TfLiteDelegateUniquePtr TF_AcquireFlexDelegate() {
|
|
return tflite::AcquireFlexDelegate();
|
|
}
|