Using input shape when deduce axes for reduction.

PiperOrigin-RevId: 346345936
Change-Id: I934baf9b6a2e79eac22d3131af7a0581ff4167bd
This commit is contained in:
Raman Sarokin 2020-12-08 09:43:10 -08:00 committed by TensorFlower Gardener
parent 3c37f87c45
commit 8206491e82

View File

@ -1534,10 +1534,10 @@ class ReduceOperationParser : public TFLiteOperationParser {
ReduceAttributes attr;
Tensor<Linear, DataType::INT32> axes;
RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
const TfLiteTensor* output = reader->GetOutputTensor(0);
const TfLiteTensor* input = reader->GetInputTensor(0);
for (int i = 0; i < axes.data.size(); i++) {
Axis axis;
RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
attr.dims.insert(axis);
}
node->operation.attributes = attr;
@ -2615,10 +2615,10 @@ class MeanOperationParser : public TFLiteOperationParser {
MeanAttributes attr;
Tensor<Linear, DataType::INT32> axes;
RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
const TfLiteTensor* output = reader->GetOutputTensor(0);
const TfLiteTensor* input = reader->GetInputTensor(0);
for (int i = 0; i < axes.data.size(); i++) {
Axis axis;
RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
attr.dims.insert(axis);
}
node->operation.attributes = attr;