Fix bug in dense shapes calculated from output size rather than node_def

PiperOrigin-RevId: 351519984
Change-Id: Ieceef64fefe9c34027bd869d06b43444fec34918
This commit is contained in:
David Rim 2021-01-12 23:05:51 -08:00 committed by TensorFlower Gardener
parent 053587a591
commit b1f3a52199

View File

@ -754,12 +754,12 @@ TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
const auto* serialized = GetInput(context, node, 0);
const int batch_size =
serialized->dims->size > 0 ? serialized->dims->data[0] : 1;
const bool missing_shape_info = data->dense_shapes.empty();
for (int i = 0; i < data->dense_size; i++) {
TfLiteTensor* dense_key_tensor =
GetOutput(context, node, data->sparse_size * 3 + i);
TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
if (data->dense_size > 0 && data->dense_shapes.empty()) {
if (missing_shape_info) {
RuntimeShape runtime_shape = GetTensorShape(dense_key_tensor);
data->dense_shapes.push_back(TfLiteToTfShape(output_size));
}