Fix for scalar tensors with BytesRequiredForTensor.

PiperOrigin-RevId: 313419954
Change-Id: Id6ce0fec1a1640d332bd55694ed23dc6b9da58da
This commit is contained in:
A. Unique TensorFlower 2020-05-27 10:48:43 -07:00 committed by TensorFlower Gardener
parent 1d0dfbde01
commit 5e6cb6e324

View File

@ -83,9 +83,13 @@ TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
size_t* bytes, size_t* type_size, size_t* bytes, size_t* type_size,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
int element_count = 1; int element_count = 1;
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
// so has 1 element.
if (flatbuffer_tensor.shape() != nullptr) {
for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) { for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
element_count *= flatbuffer_tensor.shape()->Get(n); element_count *= flatbuffer_tensor.shape()->Get(n);
} }
}
TfLiteType tf_lite_type; TfLiteType tf_lite_type;
TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(), TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),