STT-tensorflow/tensorflow/lite/delegates/gpu/cl/program_cache.cc
Raman Sarokin 709215779f CompilerOptions moved to gpu/common/task.
PiperOrigin-RevId: 343102797
Change-Id: I5571e29484f093fbff7decd744deeea12eb4b123
2020-11-18 15:52:09 -08:00

151 lines
6.0 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
#include <cstdint>
#include <string>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
#include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"
#include "tensorflow/lite/delegates/gpu/cl/util.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include <farmhash.h>
namespace tflite {
namespace gpu {
namespace cl {
// Builds a descriptor from a program's source text and compiler options.
//
// Both strings are copied into the descriptor. `fingerprint` is set to the sum
// of the 64-bit fingerprints of the code and of the options (unsigned
// arithmetic, so the sum wraps around on overflow). `use_fingerprints` is
// stored as-is in `use_fingerprint`.
ProgramCache::ProgramDescriptor::ProgramDescriptor(const std::string& code_text,
                                                   const std::string& options,
                                                   bool use_fingerprints)
    : code(code_text),
      compiler_options(options),
      use_fingerprint(use_fingerprints) {
  // Computed from the stored members, after they have been initialized.
  fingerprint = ::util::Fingerprint64(code) +
                ::util::Fingerprint64(compiler_options);
}
ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
: fingerprint(fingerprints), use_fingerprint(true) {}
ProgramCache::ProgramCache(ProgramCache&& program_cache)
: use_fingerprints_(program_cache.use_fingerprints_),
programs_(std::move(program_cache.programs_)) {}
// Move assignment: takes over the other cache's program table and copies its
// fingerprint-mode flag. Self-assignment leaves the object unchanged.
ProgramCache& ProgramCache::operator=(ProgramCache&& other) {
  if (this == &other) {
    return *this;
  }
  programs_ = std::move(other.programs_);
  use_fingerprints_ = other.use_fingerprints_;
  return *this;
}
// Creates `result` as the kernel `function_name` of the program identified by
// (`code`, `compiler_options`).
//
// The compiler options are first converted to a string using the device info.
// If a program with a matching descriptor is already cached, the kernel is
// created from it. Otherwise the program is built with CreateCLProgram, the
// kernel is created from it, and the program is then added to the cache.
// Returns the first error reported by program or kernel creation; a program
// is only cached after both steps succeed.
absl::Status ProgramCache::GetOrCreateCLKernel(
    const std::string& code, const std::string& function_name,
    const std::vector<CompilerOptions>& compiler_options,
    const CLContext& context, const CLDevice& device, CLKernel* result) {
  const std::string options_text =
      CompilerOptionsToString(device.GetInfo(), compiler_options);
  ProgramDescriptor key(code, options_text, use_fingerprints_);

  const auto cached = programs_.find(key);
  if (cached == programs_.end()) {
    // Cache miss: build the program, create the kernel, then store the
    // program for later requests.
    CLProgram fresh_program;
    RETURN_IF_ERROR(
        CreateCLProgram(code, options_text, context, device, &fresh_program));
    RETURN_IF_ERROR(result->CreateFromProgram(fresh_program, function_name));
    programs_.insert(std::make_pair(std::move(key), std::move(fresh_program)));
    return absl::OkStatus();
  }

  // Cache hit: reuse the stored program.
  return result->CreateFromProgram(cached->second, function_name);
}
// Convenience overload: same as the overload taking compiler options, called
// with an empty option list.
absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
                                               const std::string& function_name,
                                               const CLContext& context,
                                               const CLDevice& device,
                                               CLKernel* result) {
  const std::vector<CompilerOptions> no_options;
  return GetOrCreateCLKernel(code, function_name, no_options, context, device,
                             result);
}
// Loads precompiled programs from a serialized (FlatBuffers) cache into this
// program cache.
//
// Steps:
//   1. Verify the buffer with the FlatBuffers verifier.
//   2. Compare the driver version stored in the buffer with
//      device.GetPlatformVersion(); a mismatch rejects the whole cache.
//   3. Switch this cache to fingerprint mode and, for every serialized
//      program, create a CLProgram from its binary and insert it under its
//      fingerprint unless an entry with that descriptor already exists.
//
// Returns InvalidArgumentError if the buffer fails verification, lacks a
// driver version, has a different driver version, or contains a program entry
// without a binary. Returns any error from CreateCLProgramFromBinary. Programs
// inserted before an error occurred remain in the cache.
absl::Status ProgramCache::AddSerializedCache(
    const CLContext& context, const CLDevice& device,
    absl::Span<const uint8_t> serialized_cache) {
  flatbuffers::Verifier verifier(serialized_cache.data(),
                                 serialized_cache.size());
  if (!data::VerifyCompiledCacheBuffer(verifier)) {
    return absl::InvalidArgumentError("Serialized model is corrupted.");
  }

  const auto* model = data::GetCompiledCache(serialized_cache.data());

  // A buffer can pass verification while omitting non-required fields; the
  // generated accessors return nullptr in that case, so every pointer is
  // checked before it is dereferenced.
  const auto* driver_version = model->driver_version();
  if (driver_version == nullptr) {
    return absl::InvalidArgumentError(
        "Serialized model is corrupted: driver version is missing.");
  }
  const std::string platform_version(driver_version->c_str(),
                                     driver_version->size());
  if (device.GetPlatformVersion() != platform_version) {
    return absl::InvalidArgumentError(
        "OpenCL driver changed, cache invalid, should be regenerated");
  }

  use_fingerprints_ = true;

  const auto* serialized_programs = model->programs();
  if (serialized_programs == nullptr) {
    // No program list at all: treat it as an empty cache.
    return absl::OkStatus();
  }

  for (const auto* serialized_program : *serialized_programs) {
    const auto* binary = serialized_program->binary();
    if (binary == nullptr) {
      return absl::InvalidArgumentError(
          "Serialized model is corrupted: program binary is missing.");
    }
    ProgramDescriptor desc(serialized_program->fingerprint());
    CLProgram program;
    RETURN_IF_ERROR(CreateCLProgramFromBinary(
        context, device, absl::MakeSpan(binary->data(), binary->size()),
        &program));
    // Entries already present are kept; the freshly created program is
    // discarded in that case.
    auto it = programs_.find(desc);
    if (it == programs_.end()) {
      programs_.insert(std::make_pair(std::move(desc), std::move(program)));
    }
  }
  return absl::OkStatus();
}
// Serializes every cached program into a FlatBuffers CompiledCache and appends
// the resulting bytes to `*serialized_cache` (existing contents are kept).
//
// Each program is stored as its fingerprint plus the binary returned by
// CLProgram::GetBinary. The buffer also records device.GetPlatformVersion() as
// the driver version.
//
// Returns the first error reported by GetBinary; in that case
// `*serialized_cache` is left unmodified.
absl::Status ProgramCache::GetSerializedCache(
    const CLDevice& device, std::vector<uint8_t>* serialized_cache) const {
  ::flatbuffers::FlatBufferBuilder builder;
  std::vector<flatbuffers::Offset<data::Program>> serialized_programs;
  serialized_programs.reserve(programs_.size());
  for (const auto& program : programs_) {
    std::vector<uint8_t> binary;
    RETURN_IF_ERROR(program.second.GetBinary(&binary));
    // The byte vector must be written before the Program table is started.
    const auto binary_offset = builder.CreateVector(binary);
    data::ProgramBuilder program_builder(builder);
    program_builder.add_fingerprint(program.first.fingerprint);
    program_builder.add_binary(binary_offset);
    serialized_programs.push_back(program_builder.Finish());
  }
  const auto driver_version =
      builder.CreateString(device.GetPlatformVersion());
  const auto programs_s = builder.CreateVector(serialized_programs);
  data::CompiledCacheBuilder cache_builder(builder);
  cache_builder.add_driver_version(driver_version);
  cache_builder.add_programs(programs_s);
  data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());

  // Append the finished buffer after whatever the caller already stored.
  const uint8_t* buffer_begin = builder.GetBufferPointer();
  serialized_cache->insert(serialized_cache->end(), buffer_begin,
                           buffer_begin + builder.GetSize());
  return absl::OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite