Added support of HWC constant tensor in Add operation.

PiperOrigin-RevId: 314586686
Change-Id: Ib30c36c906fe5e549a26f89e8385bdb21e538038
This commit is contained in:
Raman Sarokin 2020-06-03 12:53:32 -07:00 committed by TensorFlower Gardener
parent 59092da68c
commit ec80944f58
1 changed files with 9 additions and 9 deletions

View File

@ -345,9 +345,15 @@ absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = tensor.data[0];
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
if (CheckIfLinearConvertible(constant_dims).ok()) {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
} else {
Tensor<HWC, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
}
}
}
return absl::OkStatus();
@ -363,12 +369,6 @@ class AddOperationParser : public TFLiteOperationParser {
return absl::UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
for (int i = 0; i < 2; i++) {
auto input = tflite::GetInput(context, tflite_node, i);
if (IsConstantTensor(input) && input->dims->size > 0) {
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
}
}
const TfLiteAddParams* tf_options;
return RetrieveBuiltinData(tflite_node, &tf_options);