Fix bug in dense shapes calculated from output size rather than node_def
PiperOrigin-RevId: 351519984 Change-Id: Ieceef64fefe9c34027bd869d06b43444fec34918
This commit is contained in:
parent
053587a591
commit
b1f3a52199
@ -754,12 +754,12 @@ TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const auto* serialized = GetInput(context, node, 0);
|
const auto* serialized = GetInput(context, node, 0);
|
||||||
const int batch_size =
|
const int batch_size =
|
||||||
serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
|
serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
|
||||||
|
const bool missing_shape_info = data->dense_shapes.empty();
|
||||||
for (int i = 0; i < data->dense_size; i++) {
|
for (int i = 0; i < data->dense_size; i++) {
|
||||||
TfLiteTensor* dense_key_tensor =
|
TfLiteTensor* dense_key_tensor =
|
||||||
GetOutput(context, node, data->sparse_size * 3 + i);
|
GetOutput(context, node, data->sparse_size * 3 + i);
|
||||||
TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
|
TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
|
||||||
if (data->dense_size > 0 && data->dense_shapes.empty()) {
|
if (missing_shape_info) {
|
||||||
RuntimeShape runtime_shape = GetTensorShape(dense_key_tensor);
|
RuntimeShape runtime_shape = GetTensorShape(dense_key_tensor);
|
||||||
data->dense_shapes.push_back(TfLiteToTfShape(output_size));
|
data->dense_shapes.push_back(TfLiteToTfShape(output_size));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user