Fix bug on referencing invalid reference to a tensor.
Adding tensors to the TfLiteContext will give TfLiteContext.tensors a new address, so the old references should be updated with new one. PiperOrigin-RevId: 319188033 Change-Id: I1538d6260236f7cf5a710621d6af330a8639f443
This commit is contained in:
parent
6fc54250a5
commit
fc49cbb2ad
@ -37,7 +37,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
|
if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
|
||||||
const TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
|
TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
|
||||||
if (tflite::IsConstantTensor(&tflite_tensor)) {
|
if (tflite::IsConstantTensor(&tflite_tensor)) {
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
"ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
|
"ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
|
||||||
@ -58,6 +58,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
&fp_tensor_index) != kTfLiteOk) {
|
&fp_tensor_index) != kTfLiteOk) {
|
||||||
return absl::InternalError("Could not add new tensor to graph");
|
return absl::InternalError("Could not add new tensor to graph");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remember this tensor for later.
|
// Remember this tensor for later.
|
||||||
(*quant_conversion_map)[fp_tensor_index] = tensor_idx;
|
(*quant_conversion_map)[fp_tensor_index] = tensor_idx;
|
||||||
(*quant_conversion_map)[tensor_idx] = fp_tensor_index;
|
(*quant_conversion_map)[tensor_idx] = fp_tensor_index;
|
||||||
@ -67,6 +68,9 @@ absl::Status ObjectReader::ReadNonConstantTensor(
|
|||||||
ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
|
ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
|
||||||
value->tensor.ref = fp_tensor_index;
|
value->tensor.ref = fp_tensor_index;
|
||||||
value->quant_params.emplace();
|
value->quant_params.emplace();
|
||||||
|
// tflite_tensor from the outer scope is invalidated due to calling
|
||||||
|
// CreateNewTensorWithDifferentType
|
||||||
|
tflite_tensor = context->tensors[tensor_idx];
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(
|
||||||
PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
|
PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
|
||||||
(*tensor_to_value)[fp_tensor_index] = value;
|
(*tensor_to_value)[fp_tensor_index] = value;
|
||||||
|
@ -29,8 +29,8 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
|
|||||||
TfLiteType new_type,
|
TfLiteType new_type,
|
||||||
TfLiteTensor** new_tensor,
|
TfLiteTensor** new_tensor,
|
||||||
int* new_tensor_index) {
|
int* new_tensor_index) {
|
||||||
const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
|
|
||||||
TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
|
TF_LITE_ENSURE_STATUS(context->AddTensors(context, 1, new_tensor_index));
|
||||||
|
const TfLiteTensor& original_tensor = context->tensors[original_tensor_index];
|
||||||
*new_tensor = &context->tensors[*new_tensor_index];
|
*new_tensor = &context->tensors[*new_tensor_index];
|
||||||
(*new_tensor)->type = new_type;
|
(*new_tensor)->type = new_type;
|
||||||
(*new_tensor)->allocation_type = kTfLiteArenaRw;
|
(*new_tensor)->allocation_type = kTfLiteArenaRw;
|
||||||
|
@ -33,7 +33,8 @@ namespace tflite {
|
|||||||
namespace delegates {
|
namespace delegates {
|
||||||
|
|
||||||
// Creates a new Read/Write tensor having the same shape as the original, but
|
// Creates a new Read/Write tensor having the same shape as the original, but
|
||||||
// with a different type.
|
// with a different type. Note that this might void existing references to
|
||||||
|
// tensors.
|
||||||
TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
|
TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
|
||||||
const int original_tensor_index,
|
const int original_tensor_index,
|
||||||
TfLiteType new_type,
|
TfLiteType new_type,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user