[tfls.codegen] Fix potential nullptr seg fault.
PiperOrigin-RevId: 306544979 Change-Id: I7768511d0ca5ba226d6909852fc902cd282aadb4
This commit is contained in:
parent
4b2cb67756
commit
d3c76ae3af
tensorflow/lite/experimental/support/codegen
@ -112,20 +112,24 @@ TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
|
||||
tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
|
||||
tensor_info.normalization_unit =
|
||||
FindNormalizationUnit(metadata, tensor_identifier, err);
|
||||
if (metadata->content()->content_properties_type() ==
|
||||
ContentProperties_ImageProperties) {
|
||||
if (metadata->content()
|
||||
->content_properties_as_ImageProperties()
|
||||
->color_space() == ColorSpaceType_RGB) {
|
||||
tensor_info.content_type = "image";
|
||||
tensor_info.wrapper_type = "TensorImage";
|
||||
tensor_info.processor_type = "ImageProcessor";
|
||||
return tensor_info;
|
||||
} else {
|
||||
err->Warning(
|
||||
"Found Non-RGB image on tensor (%s). Codegen currently does not "
|
||||
"support it, and regard it as a plain numeric tensor.",
|
||||
tensor_identifier.c_str());
|
||||
if (metadata->content() != nullptr &&
|
||||
metadata->content()->content_properties() != nullptr) {
|
||||
// Enter tensor wrapper type inferring
|
||||
if (metadata->content()->content_properties_type() ==
|
||||
ContentProperties_ImageProperties) {
|
||||
if (metadata->content()
|
||||
->content_properties_as_ImageProperties()
|
||||
->color_space() == ColorSpaceType_RGB) {
|
||||
tensor_info.content_type = "image";
|
||||
tensor_info.wrapper_type = "TensorImage";
|
||||
tensor_info.processor_type = "ImageProcessor";
|
||||
return tensor_info;
|
||||
} else {
|
||||
err->Warning(
|
||||
"Found Non-RGB image on tensor (%s). Codegen currently does not "
|
||||
"support it, and regard it as a plain numeric tensor.",
|
||||
tensor_identifier.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor_info.content_type = "tensor";
|
||||
@ -154,12 +158,12 @@ ModelInfo CreateModelInfo(const ModelMetadata* metadata,
|
||||
graph->input_tensor_metadata(), graph->output_tensor_metadata());
|
||||
std::vector<std::string> input_tensor_names = std::move(names.first);
|
||||
std::vector<std::string> output_tensor_names = std::move(names.second);
|
||||
for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) {
|
||||
for (int i = 0; i < input_tensor_names.size(); i++) {
|
||||
model_info.inputs.push_back(
|
||||
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
|
||||
input_tensor_names[i], true, i, err));
|
||||
}
|
||||
for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) {
|
||||
for (int i = 0; i < output_tensor_names.size(); i++) {
|
||||
model_info.outputs.push_back(
|
||||
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
|
||||
output_tensor_names[i], false, i, err));
|
||||
@ -945,6 +949,11 @@ GenerationResult AndroidJavaGenerator::Generate(
|
||||
const Model* model, const std::string& package_name,
|
||||
const std::string& model_class_name, const std::string& model_asset_path) {
|
||||
GenerationResult result;
|
||||
if (model == nullptr) {
|
||||
err_.Error(
|
||||
"Cannot read model from the buffer. Codegen will generate nothing.");
|
||||
return result;
|
||||
}
|
||||
const ModelMetadata* metadata = GetMetadataFromModel(model);
|
||||
if (metadata == nullptr) {
|
||||
err_.Error(
|
||||
|
@ -24,14 +24,22 @@ namespace codegen {
|
||||
|
||||
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
|
||||
const ModelMetadata* GetMetadataFromModel(const Model* model) {
|
||||
if (model->metadata() == nullptr) {
|
||||
if (model == nullptr || model->metadata() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
for (auto i = 0; i < model->metadata()->size(); i++) {
|
||||
if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) {
|
||||
const auto* name = model->metadata()->Get(i)->name();
|
||||
if (name != nullptr && name->str() == BUFFER_KEY) {
|
||||
const auto buffer_index = model->metadata()->Get(i)->buffer();
|
||||
const auto* buffer = model->buffers()->Get(buffer_index)->data()->data();
|
||||
return GetModelMetadata(buffer);
|
||||
if (model->buffers() == nullptr ||
|
||||
model->buffers()->size() <= buffer_index) {
|
||||
continue;
|
||||
}
|
||||
const auto* buffer_vec = model->buffers()->Get(buffer_index)->data();
|
||||
if (buffer_vec == nullptr || buffer_vec->data() == nullptr) {
|
||||
continue;
|
||||
}
|
||||
return GetModelMetadata(buffer_vec->data());
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user