Merge pull request #45694 from linux-on-ibm-z:shape_signature_fix

PiperOrigin-RevId: 350065994
Change-Id: Iaa890b4d29b00f85bbaacf3af36c2ddfef2870b6
This commit is contained in:
TensorFlower Gardener 2021-01-04 19:32:14 -08:00
commit 561e7f8758
2 changed files with 12 additions and 10 deletions

View File

@ -111,11 +111,15 @@ class Subgraph {
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization,
bool is_variable = false, const size_t rank_dims_signature = 0,
const int* dims_signature = nullptr) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization, is_variable,
rank_dims_signature, dims_signature);
bool is_variable = false, const std::vector<int>& dims_signature = {}) {
if (dims_signature.empty()) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization,
is_variable);
}
return SetTensorParametersReadWrite(
tensor_index, type, name, dims.size(), dims.data(), quantization,
is_variable, dims_signature.size(), dims_signature.data());
}
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,

View File

@ -587,11 +587,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
size_t dims_signature_rank = 0;
const int* dims_signature_data = nullptr;
std::vector<int> dims_signature = {};
if (tensor->shape_signature()) {
dims_signature_rank = tensor->shape_signature()->size();
dims_signature_data = tensor->shape_signature()->data();
dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
}
bool is_variable = tensor->is_variable();
@ -623,7 +621,7 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
} else {
if (subgraph->SetTensorParametersReadWrite(
i, type, get_name(tensor), dims, quantization, is_variable,
dims_signature_rank, dims_signature_data) != kTfLiteOk) {
dims_signature) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;