diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index dc671a47691..01f94c94888 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1271,7 +1271,7 @@ class MulOperationParser : public TFLiteOperationParser { GraphFloat32* graph, ObjectReader* reader) { RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor)); MultiplyAttributes attr; - if (constant_dims->size <= 0) { + if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) { Tensor tensor; RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor)); attr.param = tensor.data[0];