Added support of HWC constant tensor in Add operation.
PiperOrigin-RevId: 314586686 Change-Id: Ib30c36c906fe5e549a26f89e8385bdb21e538038
This commit is contained in:
parent
59092da68c
commit
ec80944f58
|
@ -345,9 +345,15 @@ absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
|
|||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
*tensor_or_scalar = tensor.data[0];
|
||||
} else {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
*tensor_or_scalar = std::move(tensor);
|
||||
if (CheckIfLinearConvertible(constant_dims).ok()) {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
*tensor_or_scalar = std::move(tensor);
|
||||
} else {
|
||||
Tensor<HWC, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
*tensor_or_scalar = std::move(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
|
@ -363,12 +369,6 @@ class AddOperationParser : public TFLiteOperationParser {
|
|||
return absl::UnimplementedError("ADD requires two input tensors.");
|
||||
}
|
||||
// TODO(eignasheva): Add shapes check.
|
||||
for (int i = 0; i < 2; i++) {
|
||||
auto input = tflite::GetInput(context, tflite_node, i);
|
||||
if (IsConstantTensor(input) && input->dims->size > 0) {
|
||||
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
|
||||
}
|
||||
}
|
||||
|
||||
const TfLiteAddParams* tf_options;
|
||||
return RetrieveBuiltinData(tflite_node, &tf_options);
|
||||
|
|
Loading…
Reference in New Issue