Added support of HWC constant tensor in Add operation.

PiperOrigin-RevId: 314586686
Change-Id: Ib30c36c906fe5e549a26f89e8385bdb21e538038
This commit is contained in:
Raman Sarokin 2020-06-03 12:53:32 -07:00 committed by TensorFlower Gardener
parent 59092da68c
commit ec80944f58
1 changed files with 9 additions and 9 deletions

View File

@ -345,9 +345,15 @@ absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = tensor.data[0]; *tensor_or_scalar = tensor.data[0];
} else { } else {
Tensor<Linear, DataType::FLOAT32> tensor; if (CheckIfLinearConvertible(constant_dims).ok()) {
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); Tensor<Linear, DataType::FLOAT32> tensor;
*tensor_or_scalar = std::move(tensor); RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
} else {
Tensor<HWC, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
*tensor_or_scalar = std::move(tensor);
}
} }
} }
return absl::OkStatus(); return absl::OkStatus();
@ -363,12 +369,6 @@ class AddOperationParser : public TFLiteOperationParser {
return absl::UnimplementedError("ADD requires two input tensors."); return absl::UnimplementedError("ADD requires two input tensors.");
} }
// TODO(eignasheva): Add shapes check. // TODO(eignasheva): Add shapes check.
for (int i = 0; i < 2; i++) {
auto input = tflite::GetInput(context, tflite_node, i);
if (IsConstantTensor(input) && input->dims->size > 0) {
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
}
}
const TfLiteAddParams* tf_options; const TfLiteAddParams* tf_options;
return RetrieveBuiltinData(tflite_node, &tf_options); return RetrieveBuiltinData(tflite_node, &tf_options);