diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index b1c6336e00e..6bd41f1033f 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -345,9 +345,15 @@ absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader, RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); *tensor_or_scalar = tensor.data[0]; } else { - Tensor tensor; - RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); - *tensor_or_scalar = std::move(tensor); + if (CheckIfLinearConvertible(constant_dims).ok()) { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); + *tensor_or_scalar = std::move(tensor); + } else { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); + *tensor_or_scalar = std::move(tensor); + } } } return absl::OkStatus(); @@ -363,12 +369,6 @@ class AddOperationParser : public TFLiteOperationParser { return absl::UnimplementedError("ADD requires two input tensors."); } // TODO(eignasheva): Add shapes check. - for (int i = 0; i < 2; i++) { - auto input = tflite::GetInput(context, tflite_node, i); - if (IsConstantTensor(input) && input->dims->size > 0) { - RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims)); - } - } const TfLiteAddParams* tf_options; return RetrieveBuiltinData(tflite_node, &tf_options);