Move FlatBuffer to GraphFloat32 converter into the common directory.
PiperOrigin-RevId: 306739537
Change-Id: I055e864b5c25afc75b45d3e067b56d38214ce99e
commit 2c8d58c57f
parent 4af24c8817
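
The conversion helper that previously lived inside the OpenCL performance_profiling sample is now exposed as BuildFromFlatBuffer() in //tensorflow/lite/delegates/gpu/common:tfl2model. The call-site change in the sample amounts to roughly the following (a sketch; the surrounding setup code is omitted here and shown in the diff below):

    // Before: helper local to performance_profiling.cc.
    GraphFloat32 graph_cl;
    RETURN_IF_ERROR(FlatBufferToGPUGraph(flatbuffer, &graph_cl));

    // After: shared converter declared in common/tfl2model.h.
    GraphFloat32 graph_cl;
    RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, &graph_cl));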
@@ -7,21 +7,11 @@ cc_binary(
    name = "performance_profiling",
    srcs = ["performance_profiling.cc"],
    deps = [
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/delegates/gpu/cl:cl_command_queue",
        "//tensorflow/lite/delegates/gpu/cl:environment",
        "//tensorflow/lite/delegates/gpu/cl:inference_context",
        "//tensorflow/lite/delegates/gpu/cl:model_hints",
        "//tensorflow/lite/delegates/gpu/cl:opencl_wrapper",
        "//tensorflow/lite/delegates/gpu/cl:precision",
        "//tensorflow/lite/delegates/gpu/cl:tensor_type",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_builder",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
        "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/delegates/gpu/common:tfl2model",
        "@com_google_absl//absl/time",
    ],
)
@@ -13,112 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <iostream>
#include <string>

#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
#include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/delegates/gpu/common/tfl2model.h"

namespace tflite {
namespace gpu {
namespace cl {
namespace {

class DelegateContext {
 public:
  bool Init(TfLiteContext* context,
            const TfLiteDelegateParams* delegate_params) {
    auto denormalized_graph =
        reinterpret_cast<GraphFloat32*>(delegate_params->delegate->data_);
    absl::Status status =
        BuildModel(context, delegate_params, denormalized_graph);
    if (!status.ok()) {
      context->ReportError(context, "Failed to convert a model: %s",
                           std::string(status.message()).c_str());
    }
    return status.ok();
  }
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      .init = [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        auto* delegate_context = new DelegateContext();
        if (!delegate_context->Init(
                context,
                reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
          delete delegate_context;
          return nullptr;
        }
        return delegate_context;
      },
      .free = [](TfLiteContext* context, void* buffer) -> void {
        delete reinterpret_cast<DelegateContext*>(buffer);
      },
      .prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      .invoke = nullptr,
  };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kRegistration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

absl::Status FlatBufferToGPUGraph(
    const std::unique_ptr<tflite::FlatBufferModel>& flatbuffer,
    GraphFloat32* graph) {
  tflite::ops::builtin::BuiltinOpResolver op_resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder interpreter_builder(*flatbuffer, op_resolver);
  if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    return absl::InternalError("Unable to prepare TfLite interpreter.");
  }
  interpreter->UseNNAPI(false);
  TfLiteDelegate delegate;
  delegate.data_ = graph;
  delegate.flags = kTfLiteDelegateFlagsNone;
  delegate.Prepare = DelegatePrepare;
  delegate.CopyFromBufferHandle = nullptr;
  delegate.CopyToBufferHandle = nullptr;
  delegate.FreeBufferHandle = nullptr;

  if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
    return absl::InternalError("Conversion from TfLite model failed.");
  }

  NullTransformationReporter reporter;
  ModelTransformer transformer(graph, &reporter);
  if (!ApplyGeneralTransformations(&transformer)) {
    return absl::InternalError("Graph general transformations failed");
  }

  return absl::OkStatus();
}
}  // namespace

absl::Status RunModelSample(const std::string& model_name) {
  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
  GraphFloat32 graph_cl;
  RETURN_IF_ERROR(FlatBufferToGPUGraph(flatbuffer, &graph_cl));
  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, &graph_cl));

  Environment env;
  RETURN_IF_ERROR(CreateEnvironment(&env));
@@ -216,6 +216,32 @@ cc_library(
    ],
)

cc_library(
    name = "tfl2model",
    srcs = ["tfl2model.cc"],
    hdrs = ["tfl2model.h"],
    deps = [
        "//tensorflow/lite:framework_lib",
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite:minimal_logging",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates/gpu/cl:api",
        "//tensorflow/lite/delegates/gpu/cl:opencl_wrapper",
        "//tensorflow/lite/delegates/gpu/cl:tensor_type_util",
        "//tensorflow/lite/delegates/gpu/common:model",
        "//tensorflow/lite/delegates/gpu/common:model_builder",
        "//tensorflow/lite/delegates/gpu/common:model_transformer",
        "//tensorflow/lite/delegates/gpu/common:status",
        "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
        "//tensorflow/lite/delegates/gpu/gl:api2",
        "//tensorflow/lite/kernels:builtin_ops",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "types",
    hdrs = ["types.h"],
tensorflow/lite/delegates/gpu/common/tfl2model.cc (new file, 104 lines)
@@ -0,0 +1,104 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/tfl2model.h"

#include <memory>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

namespace tflite {
namespace gpu {
namespace {

class DelegateContext {
 public:
  bool Init(TfLiteContext* context,
            const TfLiteDelegateParams* delegate_params) {
    auto denormalized_graph =
        reinterpret_cast<GraphFloat32*>(delegate_params->delegate->data_);
    return denormalized_graph
               ? BuildModel(context, delegate_params, denormalized_graph).ok()
               : false;
  }
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      .init = [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        auto* delegate_context = new DelegateContext();
        if (!delegate_context->Init(
                context,
                reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
          delete delegate_context;
          return nullptr;
        }
        return delegate_context;
      },
      .free = [](TfLiteContext* context, void* buffer) -> void {
        delete reinterpret_cast<DelegateContext*>(buffer);
      },
      .prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      .invoke = nullptr,
  };

  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kRegistration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}
}  // namespace

absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
                                 GraphFloat32* graph) {
  ops::builtin::BuiltinOpResolver op_resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
  if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    return absl::InternalError("Unable to prepare TfLite interpreter.");
  }
  interpreter->UseNNAPI(false);
  TfLiteDelegate delegate;
  delegate.data_ = graph;
  delegate.flags = kTfLiteDelegateFlagsNone;
  delegate.Prepare = DelegatePrepare;
  delegate.CopyFromBufferHandle = nullptr;
  delegate.CopyToBufferHandle = nullptr;
  delegate.FreeBufferHandle = nullptr;

  if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
    return absl::InternalError("Conversion from TfLite model failed.");
  }

  NullTransformationReporter reporter;
  ModelTransformer transformer(graph, &reporter);
  if (!ApplyGeneralTransformations(&transformer)) {
    return absl::InternalError("Graph general transformations failed");
  }

  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite
tensorflow/lite/delegates/gpu/common/tfl2model.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TFL2MODEL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TFL2MODEL_H_

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/model_builder.h"

namespace tflite {
namespace gpu {

// Generates a GraphFloat32 from the FlatBufferModel without specifying a
// delegate.
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
                                 GraphFloat32* graph);
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TFL2MODEL_H_
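
A minimal usage sketch of the new helper follows; the LoadGraph wrapper and its error message are illustrative only, not part of this change, and it simply mirrors what the updated performance_profiling sample does:

    #include <string>

    #include "tensorflow/lite/delegates/gpu/common/model.h"
    #include "tensorflow/lite/delegates/gpu/common/status.h"
    #include "tensorflow/lite/delegates/gpu/common/tfl2model.h"
    #include "tensorflow/lite/model_builder.h"

    // Hypothetical wrapper: load a .tflite file and convert it to a GraphFloat32.
    absl::Status LoadGraph(const std::string& model_path,
                           tflite::gpu::GraphFloat32* graph) {
      // Memory-maps the flatbuffer; returns nullptr if the file cannot be read.
      auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
      if (!flatbuffer) {
        return absl::InternalError("Failed to load model from " + model_path);
      }
      // Internally builds a throwaway interpreter, routes every op through a
      // graph-building delegate, and applies the general graph transformations.
      return tflite::gpu::BuildFromFlatBuffer(*flatbuffer, graph);
    }

Note that BuildFromFlatBuffer takes the FlatBufferModel by const reference rather than a file path, so a caller that already holds an in-memory model can pass it directly.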