Internal API extended to support serialized model building/loading.
PiperOrigin-RevId: 337211816 Change-Id: Idda1c5e23c9ec218f7645a806d27071551e82f0a
parent 5922b1e601
commit fdec32072d
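The new public surface is a pair of InferenceEnvironment methods: BuildSerializedModel compiles a GraphFloat32 into a device- and options-specific byte blob, and a new NewInferenceBuilder overload restores an InferenceBuilder directly from that blob. The sketch below shows the intended flow; it is illustrative only, with namespace qualification and the RETURN_IF_ERROR macro assumed from the surrounding delegate headers, and inf_env / graph_cl / options set up exactly as in the samples further down in this change.

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/cl/api.h"

namespace tflite {
namespace gpu {
namespace cl {

// Sketch only: build the serialized model once, then restore a builder
// from it. The blob is tied to this device and these InferenceOptions.
absl::Status BuildThenRestore(InferenceEnvironment* inf_env,
                              GraphFloat32 graph_cl,
                              const InferenceOptions& options) {
  std::vector<uint8_t> serialized_model;
  RETURN_IF_ERROR(inf_env->BuildSerializedModel(options, std::move(graph_cl),
                                                &serialized_model));

  // Later (e.g. on the next application start), skip graph conversion and
  // restore the builder directly from the serialized bytes.
  std::unique_ptr<InferenceBuilder> builder;
  RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder));

  // SetInputObjectDef/SetOutputObjectDef and Build(&runner) then proceed
  // exactly as with the graph-based builder.
  return absl::OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite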
tensorflow/lite/delegates/gpu/cl
@@ -570,6 +570,56 @@ TensorObjectDef TensorToDef(const Tensor& tensor) {
  return def;
}

CalculationsPrecision GetPrecision(const Environment& env,
                                   const InferenceOptions& options) {
  CalculationsPrecision precision;
  switch (GetPosition(options, InferencePriority::MAX_PRECISION)) {
    case 1:
      precision = CalculationsPrecision::F32;
      break;
    case 2:
      precision = CalculationsPrecision::F32_F16;
      break;
    case 3:
      precision = CalculationsPrecision::F16;
      break;
    default:
      precision = CalculationsPrecision::F16;
      break;
  }
  // Increase precision if lower precision is not supported.
  if (!env.IsSupported(precision)) {
    precision = CalculationsPrecision::F32_F16;
    if (!env.IsSupported(precision)) {
      precision = CalculationsPrecision::F32;
    }
  }
  return precision;
}

TensorStorageType GetStorageTypeFromOptions(const Environment& env,
                                            const InferenceOptions& options) {
  // Fallback to BUFFER that should be supported by default.
  std::vector<TensorStorageType> preferred_storage_types;
  if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY,
                            InferencePriority::MIN_MEMORY_USAGE) ==
      PriorityImportance::HIGHER) {
    preferred_storage_types = {GetFastestStorageType(env.device().GetInfo()),
                               TensorStorageType::BUFFER};
  } else {
    preferred_storage_types = {
        GetStorageTypeWithMinimalMemoryConsumption(env.device().GetInfo()),
        TensorStorageType::BUFFER};
  }

  for (TensorStorageType storage_type : preferred_storage_types) {
    if (env.IsSupported(storage_type)) {
      return storage_type;
    }
  }
  return TensorStorageType::UNKNOWN;
}

class InferenceBuilderImpl : public InferenceBuilder {
 public:
  explicit InferenceBuilderImpl(Environment* environment)
@@ -580,8 +630,9 @@ class InferenceBuilderImpl : public InferenceBuilder {
                          const GraphFloat32& graph) {
    context_ = absl::make_unique<InferenceContext>();
    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = GetPrecision(options);
    create_info.storage_type = GetStorageType(options);
    create_info.precision = GetPrecision(*environment_, options);
    create_info.storage_type =
        GetStorageTypeFromOptions(*environment_, options);
    if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) {
      create_info.hints.Add(ModelHints::kReduceKernelsCount);
      create_info.hints.Add(ModelHints::kFastTuning);
@@ -590,6 +641,30 @@ class InferenceBuilderImpl : public InferenceBuilder {
    }
    RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_));

#ifdef CL_DELEGATE_ALLOW_GL
    if (env_options.IsGlAware() &&
        IsGlSharingSupported(environment_->device())) {
      gl_interop_fabric_ = absl::make_unique<GlInteropFabric>(
          env_options.egl_display, environment_);
    }
    tie_factory_ = absl::make_unique<TensorTieFactory>(
        environment_, context_.get(), gl_interop_fabric_.get());
#else
    tie_factory_ =
        absl::make_unique<TensorTieFactory>(environment_, context_.get());
#endif

    inputs_ = LinkTensors(context_->GetInputIds(), AccessType::READ);
    outputs_ = LinkTensors(context_->GetOutputIds(), AccessType::WRITE);
    return absl::OkStatus();
  }

  absl::Status Initialize(const InferenceEnvironmentOptions& env_options,
                          const std::vector<uint8_t>& serialized_model) {
    context_ = absl::make_unique<InferenceContext>();
    RETURN_IF_ERROR(
        context_->RestoreDeserialized(serialized_model, environment_));

#ifdef CL_DELEGATE_ALLOW_GL
    if (env_options.IsGlAware() &&
        IsGlSharingSupported(environment_->device())) {
@@ -671,55 +746,6 @@ class InferenceBuilderImpl : public InferenceBuilder {
  }

 private:
  TensorStorageType GetStorageType(const InferenceOptions& options) const {
    // Fallback to BUFFER that should be supported by default.
    std::vector<TensorStorageType> preferred_storage_types;
    if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY,
                              InferencePriority::MIN_MEMORY_USAGE) ==
        PriorityImportance::HIGHER) {
      preferred_storage_types = {
          GetFastestStorageType(environment_->device().GetInfo()),
          TensorStorageType::BUFFER};
    } else {
      preferred_storage_types = {GetStorageTypeWithMinimalMemoryConsumption(
                                     environment_->device().GetInfo()),
                                 TensorStorageType::BUFFER};
    }

    for (TensorStorageType storage_type : preferred_storage_types) {
      if (environment_->IsSupported(storage_type)) {
        return storage_type;
      }
    }
    return TensorStorageType::UNKNOWN;
  }

  CalculationsPrecision GetPrecision(const InferenceOptions& options) const {
    CalculationsPrecision precision;
    switch (GetPosition(options, InferencePriority::MAX_PRECISION)) {
      case 1:
        precision = CalculationsPrecision::F32;
        break;
      case 2:
        precision = CalculationsPrecision::F32_F16;
        break;
      case 3:
        precision = CalculationsPrecision::F16;
        break;
      default:
        precision = CalculationsPrecision::F16;
        break;
    }
    // Increase precision if lower precision is not supported.
    if (!environment_->IsSupported(precision)) {
      precision = CalculationsPrecision::F32_F16;
      if (!environment_->IsSupported(precision)) {
        precision = CalculationsPrecision::F32;
      }
    }
    return precision;
  }

  // Links internal tensors with external user-facing objects.
  std::vector<TensorTieDef> LinkTensors(const std::vector<ValueId>& ids,
                                        AccessType access) {
@@ -840,6 +866,39 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
    return environment_.Init();
  }

  absl::Status BuildSerializedModel(
      const InferenceOptions& options, GraphFloat32 model,
      std::vector<uint8_t>* serialized_model) final {
    if (!IsValid(options)) {
      return absl::InvalidArgumentError("InferenceOptions are invalid.");
    }
    InferenceOptions resolved_options = options;
    ResolveAutoPriority(&resolved_options);
    if (environment_.program_cache() &&
        !options_.serialized_binary_cache.empty()) {
      // Ignore returned error. Cache is discarded.
      environment_.program_cache()
          ->AddSerializedCache(environment_.context(), environment_.device(),
                               options_.serialized_binary_cache)
          .IgnoreError();
    }

    RETURN_IF_ERROR(RunGraphTransforms(&model));
    InferenceContext context;
    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = GetPrecision(environment_, options);
    create_info.storage_type = GetStorageTypeFromOptions(environment_, options);
    if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) {
      create_info.hints.Add(ModelHints::kReduceKernelsCount);
      create_info.hints.Add(ModelHints::kFastTuning);
    } else if (options.usage == InferenceUsage::SUSTAINED_SPEED) {
      create_info.hints.Add(ModelHints::kAllowSpecialKernels);
    }
    RETURN_IF_ERROR(context.InitFromGraph(create_info, model, &environment_,
                                          serialized_model));
    return absl::OkStatus();
  }

  absl::Status NewInferenceBuilder(
      const InferenceOptions& options, GraphFloat32 model,
      std::unique_ptr<InferenceBuilder>* builder) final {
@@ -865,6 +924,24 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
    return absl::OkStatus();
  }

  absl::Status NewInferenceBuilder(
      const std::vector<uint8_t>& serialized_model,
      std::unique_ptr<InferenceBuilder>* builder) final {
    if (environment_.program_cache() &&
        !options_.serialized_binary_cache.empty()) {
      // Ignore returned error. Cache is discarded.
      environment_.program_cache()
          ->AddSerializedCache(environment_.context(), environment_.device(),
                               options_.serialized_binary_cache)
          .IgnoreError();
    }

    auto builder_impl = absl::make_unique<InferenceBuilderImpl>(&environment_);
    RETURN_IF_ERROR(builder_impl->Initialize(options_, serialized_model));
    *builder = std::move(builder_impl);
    return absl::OkStatus();
  }

  std::vector<uint8_t> GetSerializedBinaryCache() const final {
    std::vector<uint8_t> data;
    // If there was a problem, data would be empty.
@@ -75,6 +75,20 @@ class InferenceEnvironment {
 public:
  virtual ~InferenceEnvironment() {}

  // Converts GraphFloat32 into an intermediate, device-specific representation.
  // The serialized_model is specific to the device and InferenceOptions and
  // cannot be used with another device or different InferenceOptions.
  // Loading a serialized_model is much faster than loading a GraphFloat32.
  // serialized_model must be used with the matching NewInferenceBuilder
  // overload (see below).
  virtual absl::Status BuildSerializedModel(
      const InferenceOptions& options, GraphFloat32 model,
      std::vector<uint8_t>* serialized_model) = 0;

  virtual absl::Status NewInferenceBuilder(
      const std::vector<uint8_t>& serialized_model,
      std::unique_ptr<InferenceBuilder>* builder) = 0;

  virtual absl::Status NewInferenceBuilder(
      const InferenceOptions& options, GraphFloat32 model,
      std::unique_ptr<InferenceBuilder>* builder) = 0;
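Because the serialized model is only valid for the device and InferenceOptions it was built with, a natural pattern is to cache the bytes on disk and fall back to the GraphFloat32 path whenever loading or restoring fails. The helpers below are an illustrative sketch using plain fstream I/O; they are not part of the delegate API, and the cache-invalidation policy (device, driver, or options changes) is an assumption left to the caller.

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Illustrative helpers only; not part of the delegate API. The cache file
// should be discarded whenever the device, driver, or InferenceOptions
// change, because the serialized model is specific to all of them.
bool SaveSerializedModel(const std::string& path,
                         const std::vector<uint8_t>& blob) {
  std::ofstream out(path, std::ios::binary);
  if (!out) return false;
  out.write(reinterpret_cast<const char*>(blob.data()),
            static_cast<std::streamsize>(blob.size()));
  return out.good();
}

bool LoadSerializedModel(const std::string& path, std::vector<uint8_t>* blob) {
  std::ifstream in(path, std::ios::binary | std::ios::ate);
  if (!in) return false;
  blob->resize(static_cast<size_t>(in.tellg()));
  in.seekg(0);
  in.read(reinterpret_cast<char*>(blob->data()),
          static_cast<std::streamsize>(blob->size()));
  return in.good();
}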
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>

#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/api.h"
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
@@ -85,8 +86,18 @@ void CompareCPUGPUResults(tflite::Interpreter* cpu,
}
}  // namespace

absl::Status RunModelSampleWithInternalAPISerializedKernels(
    const std::string& model_name, const std::vector<uint8_t>& kernel_cache);

absl::Status RunModelSampleWithInternalAPISerialized(
    tflite::Interpreter* cpu, const std::vector<int64_t>& in_refs,
    const std::vector<int64_t>& out_refs,
    const std::vector<uint8_t>& kernel_cache,
    const std::vector<uint8_t>& serialized_model);

// Runs Jet with the OpenCL internal API and compares correctness with the
// TFLite CPU backend.
absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
absl::Status RunModelSampleWithInternalAPI(const std::string& model_name,
                                           std::vector<uint8_t>* kernel_cache) {
  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());

  ops::builtin::BuiltinOpResolver op_resolver;
@@ -124,6 +135,7 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
    return absl::InternalError("Failed to Invoke CPU inference.");
  }

  const auto start = std::chrono::high_resolution_clock::now();
  GraphFloat32 graph_cl;
  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl));
@@ -156,6 +168,7 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
  options.priority2 = InferencePriority::MIN_MEMORY_USAGE;
  options.priority3 = InferencePriority::MAX_PRECISION;
  options.usage = InferenceUsage::SUSTAINED_SPEED;

  RETURN_IF_ERROR(
      inf_env->NewInferenceBuilder(options, std::move(graph_cl), &builder));
@@ -176,6 +189,15 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
  // Builds runner.
  RETURN_IF_ERROR(builder->Build(&runner));

  const auto end = std::chrono::high_resolution_clock::now();
  std::cout << "Initialization total time - " << (end - start).count() * 1e-6f
            << "ms" << std::endl;

  if (kernel_cache) {
    *kernel_cache = inf_env->GetSerializedBinaryCache();
    std::cout << "Kernel cache size - " << kernel_cache->size() << std::endl;
  }

  // Sets the input/output object.
  for (int i = 0; i < in_refs.size(); ++i) {
    TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]);
@@ -198,6 +220,205 @@ absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
  return absl::OkStatus();
}

absl::Status RunModelSampleWithInternalAPISerializedKernels(
    const std::string& model_name, const std::vector<uint8_t>& kernel_cache) {
  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());

  ops::builtin::BuiltinOpResolver op_resolver;
  InterpreterBuilder tfl_builder(*flatbuffer, op_resolver);

  // CPU.
  std::unique_ptr<tflite::Interpreter> cpu_inference;
  tfl_builder(&cpu_inference);
  if (!cpu_inference) {
    return absl::InternalError("Failed to build CPU inference.");
  }
  auto status = cpu_inference->AllocateTensors();
  if (status != kTfLiteOk) {
    return absl::InternalError("Failed to AllocateTensors for CPU inference.");
  }
  for (int k = 0; k < cpu_inference->inputs().size(); ++k) {
    TfLiteTensor* tensor_ptr =
        cpu_inference->tensor(cpu_inference->inputs()[k]);
    if (tensor_ptr->type != kTfLiteFloat32) {
      return absl::InvalidArgumentError(
          "Internal api supports only F32 input tensors");
    }
  }
  for (int k = 0; k < cpu_inference->outputs().size(); ++k) {
    TfLiteTensor* tensor_ptr =
        cpu_inference->tensor(cpu_inference->outputs()[k]);
    if (tensor_ptr->type != kTfLiteFloat32) {
      return absl::InvalidArgumentError(
          "Internal api supports only F32 output tensors");
    }
  }
  FillInputTensors(cpu_inference.get());
  status = cpu_inference->Invoke();
  if (status != kTfLiteOk) {
    return absl::InternalError("Failed to Invoke CPU inference.");
  }

  const auto start = std::chrono::high_resolution_clock::now();
  GraphFloat32 graph_cl;
  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl));

  auto inputs = graph_cl.inputs();
  auto outputs = graph_cl.outputs();
  std::vector<int64_t> in_refs(inputs.size());
  std::vector<int64_t> out_refs(outputs.size());
  for (int i = 0; i < inputs.size(); ++i) {
    in_refs[i] = inputs[i]->tensor.ref;
  }
  for (int i = 0; i < outputs.size(); ++i) {
    out_refs[i] = outputs[i]->tensor.ref;
  }

  Environment env;
  RETURN_IF_ERROR(CreateEnvironment(&env));

  std::unique_ptr<InferenceEnvironment> inf_env;
  // Initializes environment.
  InferenceEnvironmentOptions env_options;
  env_options.device = env.device().id();
  env_options.context = env.context().context();
  env_options.command_queue = env.queue()->queue();
  env_options.serialized_binary_cache =
      absl::MakeSpan(kernel_cache.data(), kernel_cache.size());
  RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr));

  InferenceOptions options;
  options.priority1 = InferencePriority::MIN_LATENCY;
  options.priority2 = InferencePriority::MIN_MEMORY_USAGE;
  options.priority3 = InferencePriority::MAX_PRECISION;
  options.usage = InferenceUsage::SUSTAINED_SPEED;

  std::vector<uint8_t> serialized_model;
  RETURN_IF_ERROR(inf_env->BuildSerializedModel(options, std::move(graph_cl),
                                                &serialized_model));
  std::unique_ptr<InferenceBuilder> builder;
  RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder));

  // Sets input/output object def for builder_.
  ObjectDef obj_def;
  obj_def.data_type = DataType::FLOAT32;
  obj_def.data_layout = DataLayout::BHWC;
  obj_def.object_type = ObjectType::CPU_MEMORY;
  obj_def.user_provided = true;
  for (int i = 0; i < in_refs.size(); ++i) {
    RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def));
  }
  for (int i = 0; i < out_refs.size(); ++i) {
    RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def));
  }

  std::unique_ptr<::tflite::gpu::InferenceRunner> runner;
  // Builds runner.
  RETURN_IF_ERROR(builder->Build(&runner));

  const auto end = std::chrono::high_resolution_clock::now();
  std::cout << "Initialization total time(with kernel cache) - "
            << (end - start).count() * 1e-6f << "ms" << std::endl;

  // Sets the input/output object.
  for (int i = 0; i < in_refs.size(); ++i) {
    TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]);
    RETURN_IF_ERROR(runner->SetInputObject(
        i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes}));
  }

  std::vector<std::vector<float>> output_tensors(out_refs.size());
  for (int i = 0; i < out_refs.size(); ++i) {
    TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]);
    output_tensors[i].resize(tensor_ptr->bytes / 4);
    RETURN_IF_ERROR(runner->SetOutputObject(
        i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes}));
  }

  RETURN_IF_ERROR(runner->Run());

  CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f);

  RETURN_IF_ERROR(RunModelSampleWithInternalAPISerialized(
      cpu_inference.get(), in_refs, out_refs, kernel_cache, serialized_model));

  return absl::OkStatus();
}

absl::Status RunModelSampleWithInternalAPISerialized(
    tflite::Interpreter* cpu, const std::vector<int64_t>& in_refs,
    const std::vector<int64_t>& out_refs,
    const std::vector<uint8_t>& kernel_cache,
    const std::vector<uint8_t>& serialized_model) {
  FillInputTensors(cpu);
  auto status = cpu->Invoke();
  if (status != kTfLiteOk) {
    return absl::InternalError("Failed to Invoke CPU inference.");
  }

  const auto start = std::chrono::high_resolution_clock::now();

  Environment env;
  RETURN_IF_ERROR(CreateEnvironment(&env));

  std::unique_ptr<InferenceEnvironment> inf_env;
  // Initializes environment.
  InferenceEnvironmentOptions env_options;
  env_options.device = env.device().id();
  env_options.context = env.context().context();
  env_options.command_queue = env.queue()->queue();
  env_options.serialized_binary_cache =
      absl::MakeSpan(kernel_cache.data(), kernel_cache.size());
  RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr));

  std::unique_ptr<InferenceBuilder> builder;
  RETURN_IF_ERROR(inf_env->NewInferenceBuilder(serialized_model, &builder));

  // Sets input/output object def for builder_.
  ObjectDef obj_def;
  obj_def.data_type = DataType::FLOAT32;
  obj_def.data_layout = DataLayout::BHWC;
  obj_def.object_type = ObjectType::CPU_MEMORY;
  obj_def.user_provided = true;
  for (int i = 0; i < in_refs.size(); ++i) {
    RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def));
  }
  for (int i = 0; i < out_refs.size(); ++i) {
    RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def));
  }

  std::unique_ptr<::tflite::gpu::InferenceRunner> runner;
  // Builds runner.
  RETURN_IF_ERROR(builder->Build(&runner));

  const auto end = std::chrono::high_resolution_clock::now();
  std::cout << "Serialized initialization total time - "
            << (end - start).count() * 1e-6f << "ms" << std::endl;

  // Sets the input/output object.
  for (int i = 0; i < in_refs.size(); ++i) {
    TfLiteTensor* tensor_ptr = cpu->tensor(in_refs[i]);
    RETURN_IF_ERROR(runner->SetInputObject(
        i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes}));
  }

  std::vector<std::vector<float>> output_tensors(out_refs.size());
  for (int i = 0; i < out_refs.size(); ++i) {
    TfLiteTensor* tensor_ptr = cpu->tensor(out_refs[i]);
    output_tensors[i].resize(tensor_ptr->bytes / 4);
    RETURN_IF_ERROR(runner->SetOutputObject(
        i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes}));
  }

  RETURN_IF_ERROR(runner->Run());

  std::cout << "Comparing results second time:" << std::endl;

  CompareCPUGPUResults(cpu, out_refs, output_tensors, 1e-4f);

  return absl::OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite
@@ -214,7 +435,15 @@ int main(int argc, char** argv) {
    return -1;
  }

  auto run_status = tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1]);
  std::vector<uint8_t> kernel_cache;
  auto run_status =
      tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1], &kernel_cache);
  if (!run_status.ok()) {
    std::cerr << run_status.message();
    return -1;
  }
  run_status = tflite::gpu::cl::RunModelSampleWithInternalAPISerializedKernels(
      argv[1], kernel_cache);
  if (!run_status.ok()) {
    std::cerr << run_status.message();
    return -1;