Using input shape when deduce axes for reduction.
PiperOrigin-RevId: 346345936 Change-Id: I934baf9b6a2e79eac22d3131af7a0581ff4167bd
This commit is contained in:
parent
3c37f87c45
commit
8206491e82
@ -1534,10 +1534,10 @@ class ReduceOperationParser : public TFLiteOperationParser {
|
||||
ReduceAttributes attr;
|
||||
Tensor<Linear, DataType::INT32> axes;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
|
||||
const TfLiteTensor* output = reader->GetOutputTensor(0);
|
||||
const TfLiteTensor* input = reader->GetInputTensor(0);
|
||||
for (int i = 0; i < axes.data.size(); i++) {
|
||||
Axis axis;
|
||||
RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
|
||||
RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
|
||||
attr.dims.insert(axis);
|
||||
}
|
||||
node->operation.attributes = attr;
|
||||
@ -2615,10 +2615,10 @@ class MeanOperationParser : public TFLiteOperationParser {
|
||||
MeanAttributes attr;
|
||||
Tensor<Linear, DataType::INT32> axes;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &axes));
|
||||
const TfLiteTensor* output = reader->GetOutputTensor(0);
|
||||
const TfLiteTensor* input = reader->GetInputTensor(0);
|
||||
for (int i = 0; i < axes.data.size(); i++) {
|
||||
Axis axis;
|
||||
RETURN_IF_ERROR(ExtractAxisFromIndex(*output, axes.data[i], &axis));
|
||||
RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes.data[i], &axis));
|
||||
attr.dims.insert(axis);
|
||||
}
|
||||
node->operation.attributes = attr;
|
||||
|
Loading…
Reference in New Issue
Block a user